笔记本单元测试

可以使用单元测试来帮助提高笔记本代码的质量和一致性。 单元测试是一种尽早且经常测试自包含代码单元(例如函数)的方法。 这有助于更快地发现代码问题,更快地发现关于代码的错误假设,并简化整体编码工作。

本文介绍如何使用函数进行基本的单元测试。 单元测试类和接口等高级概念,以及存根、模拟和测试工具的使用,虽然在笔记本的单元测试时也受支持,但不在本文的讨论范围之内。 本文也不会介绍其他种类的测试方法,例如集成测试、系统测试、验收测试,或者性能测试或可用性测试等非功能测试方法。

本文演示以下内容:

  • 如何组织函数及其单元测试。
  • 如何使用 Python、R、Scala 编写函数,以及使用 SQL 编写用户定义的函数,这些函数设计精良,可以进行单元测试。
  • 如何从 Python、R、Scala 和 SQL 笔记本中调用这些函数。
  • 如何通过适用于 Python 的常用测试框架 pytest、适用于 R 的 testthat 和适用于 Scala 的 ScalaTest 使用 Python、R 和 Scala 编写单元测试。 以及如何编写对 SQL 用户定义函数 (SQL UDF) 进行单元测试的 SQL。
  • 如何从 Python、R、Scala 和 SQL 笔记本运行这些单元测试。

组织函数和单元测试

使用笔记本组织函数及其单元测试有一些常见方法。 每种方法都有其优势和挑战。

对于 Python、R 和 Scala 笔记本,常见的方法包括:

对于 Python 和 R 笔记本,Databricks 建议在笔记本外部存储函数及其单元测试。 对于 Scala 笔记本,Databricks 建议将函数包含在一个笔记本中,将其单元测试包含在一个单独的笔记本中。

对于 SQL 笔记本,Databricks 建议将函数作为 SQL 用户定义函数 (SQL UDF) 存储在架构(也称为数据库)中。 然后,可以从 SQL 笔记本调用这些 SQL UDF 及其单元测试。

编写函数

本节描述了一组简单的示例函数,它们确定以下内容:

  • 表是否存在于数据库中。
  • 某列是否存在于该表中。
  • 对于该列中的某个值,该列中存在多少行。

这些函数比较简单,因此你可以专注于本文中详述的单元测试,而不必将注意力放在函数本身上。

为了获得最好的单元测试结果,一个函数应返回单一的可预测结果并且是单一的数据类型。 例如,要检查是否存在某些内容,该函数应返回布尔值 true 或 false。 要返回存在的行数,该函数应返回一个非负整数。 在第一个示例中,如果某内容不存在,则它应返回 false,如果它存在,应返回该内容本身。 同样,对于第二个示例,它应返回存在的行数,如果不存在行,则应返回 false。

可以如下所示在 Python、R、Scala 或 SQL 中将这些函数添加到现有的 Azure Databricks 工作区。

Python

以下代码假设你已设置 Databricks Git 文件夹 (Repos)添加了存储库,并在 Azure Databricks 工作区中打开了该存储库。

在存储库中创建一个名为 myfunctions.py 的文件,并将以下内容添加到该文件中。 本文中的其他示例要求此文件命名为 myfunctions.py。 可以为自己的文件使用不同的名称。

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
                    .appName('integrity-tests') \
                    .getOrCreate()

# Does the specified table exist in the specified database?
def tableExists(tableName, dbName):
  return spark.catalog.tableExists(f"{dbName}.{tableName}")

# Does the specified column exist in the given DataFrame?
def columnExists(dataFrame, columnName):
  if columnName in dataFrame.columns:
    return True
  else:
    return False

# How many rows are there for the specified value in the specified column
# in the given DataFrame?
def numRowsInColumnForValue(dataFrame, columnName, columnValue):
  df = dataFrame.filter(col(columnName) == columnValue)

  return df.count()

R

以下代码假设你已设置 Databricks Git 文件夹 (Repos)添加了存储库,并在 Azure Databricks 工作区中打开了该存储库。

在存储库中创建一个名为 myfunctions.r 的文件,并将以下内容添加到该文件中。 本文中的其他示例要求此文件命名为 myfunctions.r。 可以为自己的文件使用不同的名称。

library(SparkR)

# Does the specified table exist in the specified database?
table_exists <- function(table_name, db_name) {
  tableExists(paste(db_name, ".", table_name, sep = ""))
}

# Does the specified column exist in the given DataFrame?
column_exists <- function(dataframe, column_name) {
  column_name %in% colnames(dataframe)
}

# How many rows are there for the specified value in the specified column
# in the given DataFrame?
num_rows_in_column_for_value <- function(dataframe, column_name, column_value) {
  df = filter(dataframe, dataframe[[column_name]] == column_value)

  count(df)
}

Scala

使用以下内容创建一个名为 myfunctionsScala 笔记本。 本文中的其他示例要求将此笔记本命名为 myfunctions。 可以为自己的笔记本使用不同的名称。

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

// Does the specified table exist in the specified database?
def tableExists(tableName: String, dbName: String) : Boolean = {
  return spark.catalog.tableExists(dbName + "." + tableName)
}

// Does the specified column exist in the given DataFrame?
def columnExists(dataFrame: DataFrame, columnName: String) : Boolean = {
  val nameOfColumn = null

  for(nameOfColumn <- dataFrame.columns) {
    if (nameOfColumn == columnName) {
      return true
    }
  }

  return false
}

// How many rows are there for the specified value in the specified column
// in the given DataFrame?
def numRowsInColumnForValue(dataFrame: DataFrame, columnName: String, columnValue: String) : Long = {
  val df = dataFrame.filter(col(columnName) === columnValue)

  return df.count()
}

SQL

以下代码假定在名为 default 的架构中拥有第三方示例数据集 diamonds,该架构位于名为 main 的目录中,可从 Azure Databricks 工作区访问。 如果要使用的目录或架构具有不同的名称,请更改以下一个或两个 USE 语句使其一致。

创建一个 SQL 笔记本并将以下内容添加到此新笔记本中。 然后将笔记本附加到群集并运行笔记本以将以下 SQL UDF 添加到指定的目录和架构。

注意

SQL UDF table_existscolumn_exists 仅适用于 Unity Catalog。 Unity Catalog 的 SQL UDF 支持以公共预览版提供。

USE CATALOG main;
USE SCHEMA default;

CREATE OR REPLACE FUNCTION table_exists(catalog_name STRING,
                                        db_name      STRING,
                                        table_name   STRING)
  RETURNS BOOLEAN
  RETURN if(
    (SELECT count(*) FROM system.information_schema.tables
     WHERE table_catalog = table_exists.catalog_name
       AND table_schema  = table_exists.db_name
       AND table_name    = table_exists.table_name) > 0,
    true,
    false
  );

CREATE OR REPLACE FUNCTION column_exists(catalog_name STRING,
                                         db_name      STRING,
                                         table_name   STRING,
                                         column_name  STRING)
  RETURNS BOOLEAN
  RETURN if(
    (SELECT count(*) FROM system.information_schema.columns
     WHERE table_catalog = column_exists.catalog_name
       AND table_schema  = column_exists.db_name
       AND table_name    = column_exists.table_name
       AND column_name   = column_exists.column_name) > 0,
    true,
    false
  );

CREATE OR REPLACE FUNCTION num_rows_for_clarity_in_diamonds(clarity_value STRING)
  RETURNS BIGINT
  RETURN SELECT count(*)
         FROM main.default.diamonds
         WHERE clarity = clarity_value

调用函数

本节介绍调用上述函数的代码。 例如,可以使用这些函数来计算表中指定值存在于指定列中的行数。 但是,在继续之前,需要检查该表是否实际存在,以及该列是否实际存在于该表中。 以下代码检查这些条件。

如果将上一部分中的函数添加到 Azure Databricks 工作区,则可以从工作区调用这些函数,如下所示。

Python

在存储库中与上述 myfunctions.py 文件相同的文件夹中创建一个 Python 笔记本,并将以下内容添加到笔记本中。 根据需要更改表名称、架构(数据库)名称、列名称和列值的变量值。 然后将笔记本附加到群集并运行笔记本以查看结果。

from myfunctions import *

tableName   = "diamonds"
dbName      = "default"
columnName  = "clarity"
columnValue = "VVS2"

# If the table exists in the specified database...
if tableExists(tableName, dbName):

  df = spark.sql(f"SELECT * FROM {dbName}.{tableName}")

  # And the specified column exists in that table...
  if columnExists(df, columnName):
    # Then report the number of rows for the specified value in that column.
    numRows = numRowsInColumnForValue(df, columnName, columnValue)

    print(f"There are {numRows} rows in '{tableName}' where '{columnName}' equals '{columnValue}'.")
  else:
    print(f"Column '{columnName}' does not exist in table '{tableName}' in schema (database) '{dbName}'.")
else:
  print(f"Table '{tableName}' does not exist in schema (database) '{dbName}'.") 

R

在存储库中与上述 myfunctions.r 文件相同的文件夹中创建一个 R 笔记本,并将以下内容添加到笔记本中。 根据需要更改表名称、架构(数据库)名称、列名称和列值的变量值。 然后将笔记本附加到群集并运行笔记本以查看结果。

library(SparkR)
source("myfunctions.r")

table_name   <- "diamonds"
db_name      <- "default"
column_name  <- "clarity"
column_value <- "VVS2"

# If the table exists in the specified database...
if (table_exists(table_name, db_name)) {

  df = sql(paste("SELECT * FROM ", db_name, ".", table_name, sep = ""))

  # And the specified column exists in that table...
  if (column_exists(df, column_name)) {
    # Then report the number of rows for the specified value in that column.
    num_rows = num_rows_in_column_for_value(df, column_name, column_value)

    print(paste("There are ", num_rows, " rows in table '", table_name, "' where '", column_name, "' equals '", column_value, "'.", sep = "")) 
  } else {
    print(paste("Column '", column_name, "' does not exist in table '", table_name, "' in schema (database) '", db_name, "'.", sep = ""))
  }

} else {
  print(paste("Table '", table_name, "' does not exist in schema (database) '", db_name, "'.", sep = ""))
}

Scala

在与之前的 myfunctions Scala 笔记本相同的文件夹中创建另一个 Scala 笔记本,并将以下内容添加到此新笔记本中。

在此新笔记本的第一个单元格中,添加以下用于调用 %run magic 的代码。 此 magic 使 myfunctions 笔记本的内容可用于新笔记本。

%run ./myfunctions

在此新笔记本的第二个单元格中,添加以下代码。 根据需要更改表名称、架构(数据库)名称、列名称和列值的变量值。 然后将笔记本附加到群集并运行笔记本以查看结果。

val tableName   = "diamonds"
val dbName      = "default"
val columnName  = "clarity"
val columnValue = "VVS2"

// If the table exists in the specified database...
if (tableExists(tableName, dbName)) {

  val df = spark.sql("SELECT * FROM " + dbName + "." + tableName)

  // And the specified column exists in that table...
  if (columnExists(df, columnName)) {
    // Then report the number of rows for the specified value in that column.
    val numRows = numRowsInColumnForValue(df, columnName, columnValue)

    println("There are " + numRows + " rows in '" + tableName + "' where '" + columnName + "' equals '" + columnValue + "'.")
  } else {
    println("Column '" + columnName + "' does not exist in table '" + tableName + "' in database '" + dbName + "'.")
  }

} else {
  println("Table '" + tableName + "' does not exist in database '" + dbName + "'.")
}

SQL

将以下代码添加到之前笔记本中的新单元格或单独笔记本中的单元格。 如有必要,更改架构或目录名称以与你的名称一致,然后运行此单元格以查看结果。

SELECT CASE
-- If the table exists in the specified catalog and schema...
WHEN
  table_exists("main", "default", "diamonds")
THEN
  -- And the specified column exists in that table...
  (SELECT CASE
   WHEN
     column_exists("main", "default", "diamonds", "clarity")
   THEN
     -- Then report the number of rows for the specified value in that column.
     printf("There are %d rows in table 'main.default.diamonds' where 'clarity' equals 'VVS2'.",
            num_rows_for_clarity_in_diamonds("VVS2"))
   ELSE
     printf("Column 'clarity' does not exist in table 'main.default.diamonds'.")
   END)
ELSE
  printf("Table 'main.default.diamonds' does not exist.")
END

写入单元测试

本节介绍测试本文开头所述的每个函数的代码。 如果将来对函数进行任何更改,可以使用单元测试来确定这些函数是否正常工作。

如果将本文开头的函数添加到 Azure Databricks 工作区,则可以将这些函数的单元测试添加到工作区,如下所示。

Python

在存储库中与前面的 myfunctions.py 文件相同的文件夹中创建另一个名为 test_myfunctions.py 的文件,并将以下内容添加到文件中。 默认情况下,pytest 会查找名称以 test_ 开头(或以 _test 结尾)的 .py 文件进行测试。 同样,默认情况下,pytest 在这些文件中查找名称以 test_ 开头的函数进行测试。

一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是创建尽可能接近生产数据的虚假数据。 以下代码示例创建要对其运行单元测试的虚假数据。

import pytest
import pyspark
from myfunctions import *
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, StringType

tableName    = "diamonds"
dbName       = "default"
columnName   = "clarity"
columnValue  = "SI2"

# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
                    .appName('integrity-tests') \
                    .getOrCreate()

# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema = StructType([ \
  StructField("_c0",     IntegerType(), True), \
  StructField("carat",   FloatType(),   True), \
  StructField("cut",     StringType(),  True), \
  StructField("color",   StringType(),  True), \
  StructField("clarity", StringType(),  True), \
  StructField("depth",   FloatType(),   True), \
  StructField("table",   IntegerType(), True), \
  StructField("price",   IntegerType(), True), \
  StructField("x",       FloatType(),   True), \
  StructField("y",       FloatType(),   True), \
  StructField("z",       FloatType(),   True), \
])

data = [ (1, 0.23, "Ideal",   "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43 ), \
         (2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31 ) ]

df = spark.createDataFrame(data, schema)

# Does the table exist?
def test_tableExists():
  assert tableExists(tableName, dbName) is True

# Does the column exist?
def test_columnExists():
  assert columnExists(df, columnName) is True

# Is there at least one row for the value in the specified column?
def test_numRowsInColumnForValue():
  assert numRowsInColumnForValue(df, columnName, columnValue) > 0

R

在存储库中与前面的 myfunctions.r 文件相同的文件夹中创建另一个名为 test_myfunctions.r 的文件,并将以下内容添加到文件中。 默认情况下,testthat 会查找名称以 test 开头的 .r 文件进行测试。

一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是创建尽可能接近生产数据的虚假数据。 以下代码示例创建要对其运行单元测试的虚假数据。

library(testthat)
source("myfunctions.r")

table_name   <- "diamonds"
db_name      <- "default"
column_name  <- "clarity"
column_value <- "SI2"

# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema <- structType(
  structField("_c0",     "integer"),
  structField("carat",   "float"),
  structField("cut",     "string"),
  structField("color",   "string"),
  structField("clarity", "string"),
  structField("depth",   "float"),
  structField("table",   "integer"),
  structField("price",   "integer"),
  structField("x",       "float"),
  structField("y",       "float"),
  structField("z",       "float"))

data <- list(list(as.integer(1), 0.23, "Ideal",   "E", "SI2", 61.5, as.integer(55), as.integer(326), 3.95, 3.98, 2.43),
             list(as.integer(2), 0.21, "Premium", "E", "SI1", 59.8, as.integer(61), as.integer(326), 3.89, 3.84, 2.31))

df <- createDataFrame(data, schema)

# Does the table exist?
test_that ("The table exists.", {
  expect_true(table_exists(table_name, db_name))
})

# Does the column exist?
test_that ("The column exists in the table.", {
  expect_true(column_exists(df, column_name))
})

# Is there at least one row for the value in the specified column?
test_that ("There is at least one row in the query result.", {
  expect_true(num_rows_in_column_for_value(df, column_name, column_value) > 0)
})

Scala

在与之前的 myfunctions Scala 笔记本相同的文件夹中创建另一个 Scala 笔记本,并将以下内容添加到此新笔记本中。

在新笔记本的第一个单元格中,添加以下用于调用 %run magic 的代码。 此 magic 使 myfunctions 笔记本的内容可用于新笔记本。

%run ./myfunctions

在第二个单元格中,添加以下代码。 此代码定义单元测试并指定如何运行它们。

一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是创建尽可能接近生产数据的虚假数据。 以下代码示例创建要对其运行单元测试的虚假数据。

import org.scalatest._
import org.apache.spark.sql.types.{StructType, StructField, IntegerType, FloatType, StringType}
import scala.collection.JavaConverters._

class DataTests extends AsyncFunSuite {

  val tableName   = "diamonds"
  val dbName      = "default"
  val columnName  = "clarity"
  val columnValue = "SI2"

  // Create fake data for the unit tests to run against.
  // In general, it is a best practice to not run unit tests
  // against functions that work with data in production.
  val schema = StructType(Array(
                 StructField("_c0",     IntegerType),
                 StructField("carat",   FloatType),
                 StructField("cut",     StringType),
                 StructField("color",   StringType),
                 StructField("clarity", StringType),
                 StructField("depth",   FloatType),
                 StructField("table",   IntegerType),
                 StructField("price",   IntegerType),
                 StructField("x",       FloatType),
                 StructField("y",       FloatType),
                 StructField("z",       FloatType)
               ))

  val data = Seq(
                  Row(1, 0.23, "Ideal",   "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43),
                  Row(2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31)
                ).asJava

  val df = spark.createDataFrame(data, schema)

  // Does the table exist?
  test("The table exists") {
    assert(tableExists(tableName, dbName) == true)
  }

  // Does the column exist?
  test("The column exists") {
    assert(columnExists(df, columnName) == true)
  }

  // Is there at least one row for the value in the specified column?
  test("There is at least one matching row") {
    assert(numRowsInColumnForValue(df, columnName, columnValue) > 0)
  }
}

nocolor.nodurations.nostacks.stats.run(new DataTests)

注意

此代码示例使用 ScalaTest 中的 FunSuite 测试样式。 有关其他可用的测试样式,请参阅为项目选择测试样式

SQL

在添加单元测试之前请注意,在一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是针对视图而不是表运行单元测试。

若要创建视图,可以从前一笔记本或单独的笔记本中的新单元格调用 CREATE VIEW 命令。 以下示例假设名为 main 的目录中名为 default 的架构(数据库)中有一个名为 diamonds 的现有表。 根据需要更改这些名称以便与你自己的名称匹配,然后仅运行该单元格。

USE CATALOG main;
USE SCHEMA default;

CREATE VIEW view_diamonds AS
SELECT * FROM diamonds;

创建视图后,将以下每个 SELECT 语句添加到前一笔记本或单独的笔记本中该语句原本所在的新单元格中。 根据需要更改名称以便与你自己的名称匹配。

SELECT if(table_exists("main", "default", "view_diamonds"),
          printf("PASS: The table 'main.default.view_diamonds' exists."),
          printf("FAIL: The table 'main.default.view_diamonds' does not exist."));

SELECT if(column_exists("main", "default", "view_diamonds", "clarity"),
          printf("PASS: The column 'clarity' exists in the table 'main.default.view_diamonds'."),
          printf("FAIL: The column 'clarity' does not exists in the table 'main.default.view_diamonds'."));

SELECT if(num_rows_for_clarity_in_diamonds("VVS2") > 0,
          printf("PASS: The table 'main.default.view_diamonds' has at least one row where the column 'clarity' equals 'VVS2'."),
          printf("FAIL: The table 'main.default.view_diamonds' does not have at least one row where the column 'clarity' equals 'VVS2'."));

运行单元测试

本节介绍如何运行在上一节中编码的单元测试。 运行单元测试时,会获得结果,其中显示哪些单元测试通过和失败。

如果将上一部分中的单元测试添加到了 Azure Databricks 工作区,则可以从工作区运行这些单元测试。 可以手动按计划运行这些单元测试。

Python

在存储库中与上述 test_myfunctions.py 文件相同的文件夹中创建 Python 笔记本,并添加以下内容。

在新笔记本的第一个单元格中添加以下代码,然后运行该单元格以调用 %pip magic。 此 magic 安装 pytest

%pip install pytest

在第二个单元格中,添加以下代码,然后运行该单元格。 结果显示哪些单元测试通过和失败。

import pytest
import sys

# Skip writing pyc files on a readonly filesystem.
sys.dont_write_bytecode = True

# Run pytest.
retcode = pytest.main([".", "-v", "-p", "no:cacheprovider"])

# Fail the cell execution if there are any test failures.
assert retcode == 0, "The pytest invocation failed. See the log for details."

R

在存储库中与上述 test_myfunctions.r 文件相同的文件夹中创建一个 R 笔记本,并添加以下内容。

在第一个单元格中添加以下代码,然后运行该单元格以调用 install.packages 函数。 此函数安装 testthat

install.packages("testthat")

在第二个单元格中,添加以下代码,然后运行该单元格。 结果显示哪些单元测试通过和失败。

library(testthat)
source("myfunctions.r")

test_dir(".", reporter = "tap")

Scala

运行上一节笔记本中的第一个和第二个单元格。 结果显示哪些单元测试通过和失败。

SQL

运行上一节笔记本中三个单元格中的每一个。 结果显示每个单元测试是通过还是失败。

如果在运行单元测试后不再需要该视图,可以删除该视图。 若要删除此视图,可将以下代码添加到前面某个笔记本中的新单元格中,然后仅运行该单元格。

DROP VIEW view_diamonds;

提示

可以在群集的驱动程序日志中查看笔记本运行结果(包括单元测试结果)。 还可为群集的日志传送指定一个位置。

可以设置持续集成和持续交付或部署 (CI/CD) 系统,例如 GitHub Actions,以便在代码发生更改时自动运行单元测试。 有关示例,请参阅笔记本的软件工程最佳做法中的 GitHub Actions 介绍。

其他资源

pytest

testthat

ScalaTest

SQL