笔记本单元测试
可以使用单元测试来帮助提高笔记本代码的质量和一致性。 单元测试是一种尽早且经常测试自包含代码单元(例如函数)的方法。 这有助于更快地发现代码问题,更快地发现关于代码的错误假设,并简化整体编码工作。
本文介绍如何使用函数进行基本的单元测试。 单元测试类和接口等高级概念,以及存根、模拟和测试工具的使用,虽然在笔记本的单元测试时也受支持,但不在本文的讨论范围之内。 本文也不会介绍其他种类的测试方法,例如集成测试、系统测试、验收测试,或者性能测试或可用性测试等非功能测试方法。
本文演示以下内容:
- 如何组织函数及其单元测试。
- 如何使用 Python、R、Scala 编写函数,以及使用 SQL 编写用户定义的函数,这些函数设计精良,可以进行单元测试。
- 如何从 Python、R、Scala 和 SQL 笔记本中调用这些函数。
- 如何通过适用于 Python 的常用测试框架 pytest、适用于 R 的 testthat 和适用于 Scala 的 ScalaTest 使用 Python、R 和 Scala 编写单元测试。 以及如何编写对 SQL 用户定义函数 (SQL UDF) 进行单元测试的 SQL。
- 如何从 Python、R、Scala 和 SQL 笔记本运行这些单元测试。
组织函数和单元测试
使用笔记本组织函数及其单元测试有一些常见方法。 每种方法都有其优势和挑战。
对于 Python、R 和 Scala 笔记本,常见的方法包括:
- 将函数及其单元测试存储在笔记本外部。
- 优势:可以在笔记本内外调用这些函数。 测试框架更适合在笔记本外部运行测试。
- 挑战:Scala 笔记本不支持这种方法。 此方法还增加了要跟踪和维护的文件数量。
- 将函数存储在一个笔记本中,将其单元测试存储在一个单独的笔记本中。
- 优势:这些函数更容易跨笔记本重复使用。
- 挑战:要跟踪和维护的笔记本数量增加。 这些函数不能在笔记本外部使用。 这些函数在笔记本外部也更难进行测试。
- 将函数及其单元测试存储在同一个笔记本中。
- 优势:函数及其单元测试存储在一个笔记本中,以便于跟踪和维护。
- 挑战:这些函数可能更难以跨笔记本重复使用。 这些函数不能在笔记本外部使用。 这些函数在笔记本外部也更难进行测试。
对于 Python 和 R 笔记本,Databricks 建议在笔记本外部存储函数及其单元测试。 对于 Scala 笔记本,Databricks 建议将函数包含在一个笔记本中,将其单元测试包含在一个单独的笔记本中。
对于 SQL 笔记本,Databricks 建议将函数作为 SQL 用户定义函数 (SQL UDF) 存储在架构(也称为数据库)中。 然后,可以从 SQL 笔记本调用这些 SQL UDF 及其单元测试。
编写函数
本节描述了一组简单的示例函数,它们确定以下内容:
- 表是否存在于数据库中。
- 某列是否存在于该表中。
- 对于该列中的某个值,该列中存在多少行。
这些函数比较简单,因此你可以专注于本文中详述的单元测试,而不必将注意力放在函数本身上。
为了获得最好的单元测试结果,一个函数应返回单一的可预测结果并且是单一的数据类型。 例如,要检查是否存在某些内容,该函数应返回布尔值 true 或 false。 要返回存在的行数,该函数应返回一个非负整数。 在第一个示例中,如果某内容不存在,则它应返回 false,如果它存在,应返回该内容本身。 同样,对于第二个示例,它应返回存在的行数,如果不存在行,则应返回 false。
可以如下所示在 Python、R、Scala 或 SQL 中将这些函数添加到现有的 Azure Databricks 工作区。
Python
以下代码假设你已设置 Databricks Git 文件夹 (Repos),添加了存储库,并在 Azure Databricks 工作区中打开了该存储库。
在存储库中创建一个名为 myfunctions.py
的文件,并将以下内容添加到该文件中。 本文中的其他示例要求此文件命名为 myfunctions.py
。 可以为自己的文件使用不同的名称。
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
.appName('integrity-tests') \
.getOrCreate()
# Does the specified table exist in the specified database?
def tableExists(tableName, dbName):
return spark.catalog.tableExists(f"{dbName}.{tableName}")
# Does the specified column exist in the given DataFrame?
def columnExists(dataFrame, columnName):
if columnName in dataFrame.columns:
return True
else:
return False
# How many rows are there for the specified value in the specified column
# in the given DataFrame?
def numRowsInColumnForValue(dataFrame, columnName, columnValue):
df = dataFrame.filter(col(columnName) == columnValue)
return df.count()
R
以下代码假设你已设置 Databricks Git 文件夹 (Repos),添加了存储库,并在 Azure Databricks 工作区中打开了该存储库。
在存储库中创建一个名为 myfunctions.r
的文件,并将以下内容添加到该文件中。 本文中的其他示例要求此文件命名为 myfunctions.r
。 可以为自己的文件使用不同的名称。
library(SparkR)
# Does the specified table exist in the specified database?
table_exists <- function(table_name, db_name) {
tableExists(paste(db_name, ".", table_name, sep = ""))
}
# Does the specified column exist in the given DataFrame?
column_exists <- function(dataframe, column_name) {
column_name %in% colnames(dataframe)
}
# How many rows are there for the specified value in the specified column
# in the given DataFrame?
num_rows_in_column_for_value <- function(dataframe, column_name, column_value) {
df = filter(dataframe, dataframe[[column_name]] == column_value)
count(df)
}
Scala
使用以下内容创建一个名为 myfunctions
的 Scala 笔记本。 本文中的其他示例要求将此笔记本命名为 myfunctions
。 可以为自己的笔记本使用不同的名称。
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
// Does the specified table exist in the specified database?
def tableExists(tableName: String, dbName: String) : Boolean = {
return spark.catalog.tableExists(dbName + "." + tableName)
}
// Does the specified column exist in the given DataFrame?
def columnExists(dataFrame: DataFrame, columnName: String) : Boolean = {
val nameOfColumn = null
for(nameOfColumn <- dataFrame.columns) {
if (nameOfColumn == columnName) {
return true
}
}
return false
}
// How many rows are there for the specified value in the specified column
// in the given DataFrame?
def numRowsInColumnForValue(dataFrame: DataFrame, columnName: String, columnValue: String) : Long = {
val df = dataFrame.filter(col(columnName) === columnValue)
return df.count()
}
SQL
以下代码假定在名为 default
的架构中拥有第三方示例数据集 diamonds,该架构位于名为 main
的目录中,可从 Azure Databricks 工作区访问。 如果要使用的目录或架构具有不同的名称,请更改以下一个或两个 USE
语句使其一致。
创建一个 SQL 笔记本并将以下内容添加到此新笔记本中。 然后将笔记本附加到群集并运行笔记本以将以下 SQL UDF 添加到指定的目录和架构。
注意
SQL UDF table_exists
和 column_exists
仅适用于 Unity Catalog。 Unity Catalog 的 SQL UDF 支持以公共预览版提供。
USE CATALOG main;
USE SCHEMA default;
CREATE OR REPLACE FUNCTION table_exists(catalog_name STRING,
db_name STRING,
table_name STRING)
RETURNS BOOLEAN
RETURN if(
(SELECT count(*) FROM system.information_schema.tables
WHERE table_catalog = table_exists.catalog_name
AND table_schema = table_exists.db_name
AND table_name = table_exists.table_name) > 0,
true,
false
);
CREATE OR REPLACE FUNCTION column_exists(catalog_name STRING,
db_name STRING,
table_name STRING,
column_name STRING)
RETURNS BOOLEAN
RETURN if(
(SELECT count(*) FROM system.information_schema.columns
WHERE table_catalog = column_exists.catalog_name
AND table_schema = column_exists.db_name
AND table_name = column_exists.table_name
AND column_name = column_exists.column_name) > 0,
true,
false
);
CREATE OR REPLACE FUNCTION num_rows_for_clarity_in_diamonds(clarity_value STRING)
RETURNS BIGINT
RETURN SELECT count(*)
FROM main.default.diamonds
WHERE clarity = clarity_value
调用函数
本节介绍调用上述函数的代码。 例如,可以使用这些函数来计算表中指定值存在于指定列中的行数。 但是,在继续之前,需要检查该表是否实际存在,以及该列是否实际存在于该表中。 以下代码检查这些条件。
如果将上一部分中的函数添加到 Azure Databricks 工作区,则可以从工作区调用这些函数,如下所示。
Python
在存储库中与上述 myfunctions.py
文件相同的文件夹中创建一个 Python 笔记本,并将以下内容添加到笔记本中。 根据需要更改表名称、架构(数据库)名称、列名称和列值的变量值。 然后将笔记本附加到群集并运行笔记本以查看结果。
from myfunctions import *
tableName = "diamonds"
dbName = "default"
columnName = "clarity"
columnValue = "VVS2"
# If the table exists in the specified database...
if tableExists(tableName, dbName):
df = spark.sql(f"SELECT * FROM {dbName}.{tableName}")
# And the specified column exists in that table...
if columnExists(df, columnName):
# Then report the number of rows for the specified value in that column.
numRows = numRowsInColumnForValue(df, columnName, columnValue)
print(f"There are {numRows} rows in '{tableName}' where '{columnName}' equals '{columnValue}'.")
else:
print(f"Column '{columnName}' does not exist in table '{tableName}' in schema (database) '{dbName}'.")
else:
print(f"Table '{tableName}' does not exist in schema (database) '{dbName}'.")
R
在存储库中与上述 myfunctions.r
文件相同的文件夹中创建一个 R 笔记本,并将以下内容添加到笔记本中。 根据需要更改表名称、架构(数据库)名称、列名称和列值的变量值。 然后将笔记本附加到群集并运行笔记本以查看结果。
library(SparkR)
source("myfunctions.r")
table_name <- "diamonds"
db_name <- "default"
column_name <- "clarity"
column_value <- "VVS2"
# If the table exists in the specified database...
if (table_exists(table_name, db_name)) {
df = sql(paste("SELECT * FROM ", db_name, ".", table_name, sep = ""))
# And the specified column exists in that table...
if (column_exists(df, column_name)) {
# Then report the number of rows for the specified value in that column.
num_rows = num_rows_in_column_for_value(df, column_name, column_value)
print(paste("There are ", num_rows, " rows in table '", table_name, "' where '", column_name, "' equals '", column_value, "'.", sep = ""))
} else {
print(paste("Column '", column_name, "' does not exist in table '", table_name, "' in schema (database) '", db_name, "'.", sep = ""))
}
} else {
print(paste("Table '", table_name, "' does not exist in schema (database) '", db_name, "'.", sep = ""))
}
Scala
在与之前的 myfunctions
Scala 笔记本相同的文件夹中创建另一个 Scala 笔记本,并将以下内容添加到此新笔记本中。
在此新笔记本的第一个单元格中,添加以下用于调用 %run magic 的代码。 此 magic 使 myfunctions
笔记本的内容可用于新笔记本。
%run ./myfunctions
在此新笔记本的第二个单元格中,添加以下代码。 根据需要更改表名称、架构(数据库)名称、列名称和列值的变量值。 然后将笔记本附加到群集并运行笔记本以查看结果。
val tableName = "diamonds"
val dbName = "default"
val columnName = "clarity"
val columnValue = "VVS2"
// If the table exists in the specified database...
if (tableExists(tableName, dbName)) {
val df = spark.sql("SELECT * FROM " + dbName + "." + tableName)
// And the specified column exists in that table...
if (columnExists(df, columnName)) {
// Then report the number of rows for the specified value in that column.
val numRows = numRowsInColumnForValue(df, columnName, columnValue)
println("There are " + numRows + " rows in '" + tableName + "' where '" + columnName + "' equals '" + columnValue + "'.")
} else {
println("Column '" + columnName + "' does not exist in table '" + tableName + "' in database '" + dbName + "'.")
}
} else {
println("Table '" + tableName + "' does not exist in database '" + dbName + "'.")
}
SQL
将以下代码添加到之前笔记本中的新单元格或单独笔记本中的单元格。 如有必要,更改架构或目录名称以与你的名称一致,然后运行此单元格以查看结果。
SELECT CASE
-- If the table exists in the specified catalog and schema...
WHEN
table_exists("main", "default", "diamonds")
THEN
-- And the specified column exists in that table...
(SELECT CASE
WHEN
column_exists("main", "default", "diamonds", "clarity")
THEN
-- Then report the number of rows for the specified value in that column.
printf("There are %d rows in table 'main.default.diamonds' where 'clarity' equals 'VVS2'.",
num_rows_for_clarity_in_diamonds("VVS2"))
ELSE
printf("Column 'clarity' does not exist in table 'main.default.diamonds'.")
END)
ELSE
printf("Table 'main.default.diamonds' does not exist.")
END
写入单元测试
本节介绍测试本文开头所述的每个函数的代码。 如果将来对函数进行任何更改,可以使用单元测试来确定这些函数是否正常工作。
如果将本文开头的函数添加到 Azure Databricks 工作区,则可以将这些函数的单元测试添加到工作区,如下所示。
Python
在存储库中与前面的 myfunctions.py
文件相同的文件夹中创建另一个名为 test_myfunctions.py
的文件,并将以下内容添加到文件中。 默认情况下,pytest
会查找名称以 test_
开头(或以 _test
结尾)的 .py
文件进行测试。 同样,默认情况下,pytest
在这些文件中查找名称以 test_
开头的函数进行测试。
一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是创建尽可能接近生产数据的虚假数据。 以下代码示例创建要对其运行单元测试的虚假数据。
import pytest
import pyspark
from myfunctions import *
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, FloatType, StringType
tableName = "diamonds"
dbName = "default"
columnName = "clarity"
columnValue = "SI2"
# Because this file is not a Databricks notebook, you
# must create a Spark session. Databricks notebooks
# create a Spark session for you by default.
spark = SparkSession.builder \
.appName('integrity-tests') \
.getOrCreate()
# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema = StructType([ \
StructField("_c0", IntegerType(), True), \
StructField("carat", FloatType(), True), \
StructField("cut", StringType(), True), \
StructField("color", StringType(), True), \
StructField("clarity", StringType(), True), \
StructField("depth", FloatType(), True), \
StructField("table", IntegerType(), True), \
StructField("price", IntegerType(), True), \
StructField("x", FloatType(), True), \
StructField("y", FloatType(), True), \
StructField("z", FloatType(), True), \
])
data = [ (1, 0.23, "Ideal", "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43 ), \
(2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31 ) ]
df = spark.createDataFrame(data, schema)
# Does the table exist?
def test_tableExists():
assert tableExists(tableName, dbName) is True
# Does the column exist?
def test_columnExists():
assert columnExists(df, columnName) is True
# Is there at least one row for the value in the specified column?
def test_numRowsInColumnForValue():
assert numRowsInColumnForValue(df, columnName, columnValue) > 0
R
在存储库中与前面的 myfunctions.r
文件相同的文件夹中创建另一个名为 test_myfunctions.r
的文件,并将以下内容添加到文件中。 默认情况下,testthat
会查找名称以 test
开头的 .r
文件进行测试。
一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是创建尽可能接近生产数据的虚假数据。 以下代码示例创建要对其运行单元测试的虚假数据。
library(testthat)
source("myfunctions.r")
table_name <- "diamonds"
db_name <- "default"
column_name <- "clarity"
column_value <- "SI2"
# Create fake data for the unit tests to run against.
# In general, it is a best practice to not run unit tests
# against functions that work with data in production.
schema <- structType(
structField("_c0", "integer"),
structField("carat", "float"),
structField("cut", "string"),
structField("color", "string"),
structField("clarity", "string"),
structField("depth", "float"),
structField("table", "integer"),
structField("price", "integer"),
structField("x", "float"),
structField("y", "float"),
structField("z", "float"))
data <- list(list(as.integer(1), 0.23, "Ideal", "E", "SI2", 61.5, as.integer(55), as.integer(326), 3.95, 3.98, 2.43),
list(as.integer(2), 0.21, "Premium", "E", "SI1", 59.8, as.integer(61), as.integer(326), 3.89, 3.84, 2.31))
df <- createDataFrame(data, schema)
# Does the table exist?
test_that ("The table exists.", {
expect_true(table_exists(table_name, db_name))
})
# Does the column exist?
test_that ("The column exists in the table.", {
expect_true(column_exists(df, column_name))
})
# Is there at least one row for the value in the specified column?
test_that ("There is at least one row in the query result.", {
expect_true(num_rows_in_column_for_value(df, column_name, column_value) > 0)
})
Scala
在与之前的 myfunctions
Scala 笔记本相同的文件夹中创建另一个 Scala 笔记本,并将以下内容添加到此新笔记本中。
在新笔记本的第一个单元格中,添加以下用于调用 %run
magic 的代码。 此 magic 使 myfunctions
笔记本的内容可用于新笔记本。
%run ./myfunctions
在第二个单元格中,添加以下代码。 此代码定义单元测试并指定如何运行它们。
一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是创建尽可能接近生产数据的虚假数据。 以下代码示例创建要对其运行单元测试的虚假数据。
import org.scalatest._
import org.apache.spark.sql.types.{StructType, StructField, IntegerType, FloatType, StringType}
import scala.collection.JavaConverters._
class DataTests extends AsyncFunSuite {
val tableName = "diamonds"
val dbName = "default"
val columnName = "clarity"
val columnValue = "SI2"
// Create fake data for the unit tests to run against.
// In general, it is a best practice to not run unit tests
// against functions that work with data in production.
val schema = StructType(Array(
StructField("_c0", IntegerType),
StructField("carat", FloatType),
StructField("cut", StringType),
StructField("color", StringType),
StructField("clarity", StringType),
StructField("depth", FloatType),
StructField("table", IntegerType),
StructField("price", IntegerType),
StructField("x", FloatType),
StructField("y", FloatType),
StructField("z", FloatType)
))
val data = Seq(
Row(1, 0.23, "Ideal", "E", "SI2", 61.5, 55, 326, 3.95, 3.98, 2.43),
Row(2, 0.21, "Premium", "E", "SI1", 59.8, 61, 326, 3.89, 3.84, 2.31)
).asJava
val df = spark.createDataFrame(data, schema)
// Does the table exist?
test("The table exists") {
assert(tableExists(tableName, dbName) == true)
}
// Does the column exist?
test("The column exists") {
assert(columnExists(df, columnName) == true)
}
// Is there at least one row for the value in the specified column?
test("There is at least one matching row") {
assert(numRowsInColumnForValue(df, columnName, columnValue) > 0)
}
}
nocolor.nodurations.nostacks.stats.run(new DataTests)
注意
此代码示例使用 ScalaTest 中的 FunSuite
测试样式。 有关其他可用的测试样式,请参阅为项目选择测试样式。
SQL
在添加单元测试之前请注意,在一般情况下,最佳做法是不要对处理生产数据的函数运行单元测试。 这对于添加、删除或以其他方式更改数据的函数尤其重要。 为了防止单元测试以意外的方式泄露生产数据,应该针对非生产数据运行单元测试。 一种常见做法是针对视图而不是表运行单元测试。
若要创建视图,可以从前一笔记本或单独的笔记本中的新单元格调用 CREATE VIEW 命令。 以下示例假设名为 main
的目录中名为 default
的架构(数据库)中有一个名为 diamonds
的现有表。 根据需要更改这些名称以便与你自己的名称匹配,然后仅运行该单元格。
USE CATALOG main;
USE SCHEMA default;
CREATE VIEW view_diamonds AS
SELECT * FROM diamonds;
创建视图后,将以下每个 SELECT
语句添加到前一笔记本或单独的笔记本中该语句原本所在的新单元格中。 根据需要更改名称以便与你自己的名称匹配。
SELECT if(table_exists("main", "default", "view_diamonds"),
printf("PASS: The table 'main.default.view_diamonds' exists."),
printf("FAIL: The table 'main.default.view_diamonds' does not exist."));
SELECT if(column_exists("main", "default", "view_diamonds", "clarity"),
printf("PASS: The column 'clarity' exists in the table 'main.default.view_diamonds'."),
printf("FAIL: The column 'clarity' does not exists in the table 'main.default.view_diamonds'."));
SELECT if(num_rows_for_clarity_in_diamonds("VVS2") > 0,
printf("PASS: The table 'main.default.view_diamonds' has at least one row where the column 'clarity' equals 'VVS2'."),
printf("FAIL: The table 'main.default.view_diamonds' does not have at least one row where the column 'clarity' equals 'VVS2'."));
运行单元测试
本节介绍如何运行在上一节中编码的单元测试。 运行单元测试时,会获得结果,其中显示哪些单元测试通过和失败。
如果将上一部分中的单元测试添加到了 Azure Databricks 工作区,则可以从工作区运行这些单元测试。 可以手动或按计划运行这些单元测试。
Python
在存储库中与上述 test_myfunctions.py
文件相同的文件夹中创建 Python 笔记本,并添加以下内容。
在新笔记本的第一个单元格中添加以下代码,然后运行该单元格以调用 %pip
magic。 此 magic 安装 pytest
。
%pip install pytest
在第二个单元格中,添加以下代码,然后运行该单元格。 结果显示哪些单元测试通过和失败。
import pytest
import sys
# Skip writing pyc files on a readonly filesystem.
sys.dont_write_bytecode = True
# Run pytest.
retcode = pytest.main([".", "-v", "-p", "no:cacheprovider"])
# Fail the cell execution if there are any test failures.
assert retcode == 0, "The pytest invocation failed. See the log for details."
R
在存储库中与上述 test_myfunctions.r
文件相同的文件夹中创建一个 R 笔记本,并添加以下内容。
在第一个单元格中添加以下代码,然后运行该单元格以调用 install.packages
函数。 此函数安装 testthat
。
install.packages("testthat")
在第二个单元格中,添加以下代码,然后运行该单元格。 结果显示哪些单元测试通过和失败。
library(testthat)
source("myfunctions.r")
test_dir(".", reporter = "tap")
Scala
运行上一节笔记本中的第一个和第二个单元格。 结果显示哪些单元测试通过和失败。
SQL
运行上一节笔记本中三个单元格中的每一个。 结果显示每个单元测试是通过还是失败。
如果在运行单元测试后不再需要该视图,可以删除该视图。 若要删除此视图,可将以下代码添加到前面某个笔记本中的新单元格中,然后仅运行该单元格。
DROP VIEW view_diamonds;
提示
可以在群集的驱动程序日志中查看笔记本运行结果(包括单元测试结果)。 还可为群集的日志传送指定一个位置。
可以设置持续集成和持续交付或部署 (CI/CD) 系统,例如 GitHub Actions,以便在代码发生更改时自动运行单元测试。 有关示例,请参阅笔记本的软件工程最佳做法中的 GitHub Actions 介绍。