Python 用户定义表函数 (UDTF)

重要

此功能在 Databricks Runtime 14.3 LTS 及更高版本中作为公共预览版提供。

用户定义的表函数 (UDTF) 允许注册返回表而不是标量值的函数。 与每次调用都返回单个结果值的标量函数不同,每个 UDTF 都在 SQL 语句的 FROM 子句中调用,并返回整个表作为输出。

每个 UDTF 调用都可以接受零个或多个参数。 这些参数可以是标量表达式或代表整个输入表的表参数。

基本 UDTF 语法

Apache Spark 通过必需的 eval 方法使用 yield 发出输出行来将 Python UDDF 实现为 Python 类。

若要将类用作 UDTF,必须导入 PySpark udtf 函数。 Databricks 建议将此函数用作修饰器,并且使用 returnType 选项显式指定字段名称和类型(除非类定义 analyze 方法,如后面的部分所述)。

以下 UDTF 使用两个整数参数的固定列表创建一个表:

from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
|   3|   -1|
+----+-----+

注册 UDTF

UDTF 将注册到本地 SparkSession,并在笔记本或作业级别隔离。

不能将 UDTF 注册为 Unity Catalog 中的对象,UDDF 不能用于 SQL 仓库。

可以将 UDTF 注册到当前 SparkSession,以便使用函数 spark.udtf.register() 进行 SQL 查询。 提供 SQL 函数和 Python UDTF 类的名称。

spark.udtf.register("get_sum_diff", GetSumDiff)

调用已注册的 UDTF

注册后,可以使用 %sql magic 命令或 spark.sql() 函数在 SQL 中使用 UDTF:

spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);")
%sql
SELECT * FROM get_sum_diff(1,2);

使用 Apache Arrow

如果 UDTF 接收少量数据作为输入,但输出大型表,则 Databricks 建议使用 Apache Arrow。 可以通过在声明 UDTF 时指定 useArrow 参数来启用它:

@udtf(returnType="c1: int, c2: int", useArrow=True)

变量参数列表 - *args 和 **kwargs

可以使用 Python *args**kwargs 语法并实现逻辑来处理未指定数量的输入值。

以下示例会返回相同的结果,同时显式检查参数的输入长度和类型:

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, *args):
        assert(len(args) == 2)
        assert(isinstance(arg, int) for arg in args)
        x = args[0]
        y = args[1]
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()

下面是相同的示例,但使用了关键字参数:

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, **kwargs):
        x = kwargs["x"]
        y = kwargs["y"]
        yield x + y, x - y

GetSumDiff(x=lit(1), y=lit(2)).show()

在注册时定义静态架构

UDTF 返回带有输出架构的行,该架构由列名和类型的有序序列组成。 如果 UDTF 架构对于所有查询都应始终保持不变,则可以在 @udtf 装饰器后指定静态固定架构。 它必须是 StructType

StructType().add("c1", StringType())

或表示结构类型的 DDL 字符串:

c1: string

在函数调用时计算动态架构

UDDF 还可以根据输入参数的值以编程方式计算每个调用的输出架构。 为此,请定义一个名为 analyze 的静态方法,该方法接受与提供给特定 UDTF 调用的参数对应的零个或多个参数。

analyze 方法的每个参数都是 AnalyzeArgument 类的实例,其中包含以下字段:

AnalyzeArgument 类字段 说明
dataType 作为 DataType 的输入参数的类型。 对于输入表参数,这是一个 StructType,表示表的列。
value 作为 Optional[Any] 的输入参数的值。 对于非常数表参数或文本标量参数,这是 None
isTable 输入参数是否是作为 BooleanType 的表。
isConstantExpression 输入参数是否是作为 BooleanType 的常数可折叠表达式。

analyze 方法返回 AnalyzeResult 类的实例,其中包括结果表的架构(作为 StructType)以及一些可选字段。 如果 UDTF 接受输入表参数,则 AnalyzeResult 还可以包含一种请求的方法,用于在多个 UDTF 调用中对输入表的行进行分区和排序,如下文所述。

AnalyzeResult 类字段 说明
schema 作为 StructType 的结果表的架构。
withSinglePartition 是否将所有输入行发送到作为 BooleanType 的 UDTF 类实例。
partitionBy 如果设置为非空,则具有分区表达式的每个唯一值组合的所有行都将由 UDTF 类的单独实例使用。
orderBy 如果设置为非空,则指定每个分区中的行的顺序。
select 如果设置为非空,则这是 UDTF 为 Catalyst 指定的表达式序列,用于根据输入 TABLE 参数中的列进行评估。 UDTF 按列出的顺序接收列表中的每个名称的一个输入属性。

analyze 示例为输入字符串参数中的每个单词返回一个输出列。

@udtf
class MyUDTF:
  @staticmethod
  def analyze(text: AnalyzeArgument) -> AnalyzeResult:
    schema = StructType()
    for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
      schema = schema.add(f"word_{index}", IntegerType())
    return AnalyzeResult(schema=schema)

  def eval(self, text: str):
    counts = {}
    for word in text.split(" "):
      if word not in counts:
            counts[word] = 0
      counts[word] += 1
    result = []
    for word in sorted(list(set(text.split(" ")))):
      result.append(counts[word])
    yield result
['word_0', 'word_1']

将状态转发到将来的 eval 调用

analyze 方法可用作执行初始化的便捷位置,然后将结果转发给同一 UDTF 调用的未来 eval 方法调用。

为此,请创建 AnalyzeResult 的子类,并从 analyze 方法返回该子类的实例。 然后,将附加参数添加到 __init__ 方法以接受该实例。

analyze 示例返回常数输出架构,但在结果元数据中添加自定义信息,供未来 __init__ 方法调用使用:

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""

@udtf
class TestUDTF:
  def __init__(self, analyze_result=None):
    self._total = 0
    if analyze_result is not None:
      self._buffer = analyze_result.buffer
    else:
      self._buffer = ""

  @staticmethod
  def analyze(argument, _) -> AnalyzeResult:
    if (
      argument.value is None
      or argument.isTable
      or not isinstance(argument.value, str)
      or len(argument.value) == 0
    ):
      raise Exception("The first argument must be a non-empty string")
    assert argument.dataType == StringType()
    assert not argument.isTable
    return AnalyzeResultWithBuffer(
      schema=StructType()
        .add("total", IntegerType())
        .add("buffer", StringType()),
      withSinglePartition=True,
      buffer=argument.value,
    )

  def eval(self, argument, row: Row):
    self._total += 1

  def terminate(self):
    yield self._total, self._buffer

self.spark.udtf.register("test_udtf", TestUDTF)

spark.sql(
  """
  WITH t AS (
    SELECT id FROM range(1, 21)
  )
  SELECT total, buffer
  FROM test_udtf("abc", TABLE(t))
  """
).show()
+-------+-------+
| count | buffer|
+-------+-------+
|    20 |  "abc"|
+-------+-------+

生成输出行

eval 方法针对输入表参数的每一行运行一次(如果未提供任何表参数,则只运行一次),然后在末尾调用一次 terminate 方法。 方法通过生成元组、列表或 pyspark.sql.Row 对象来输出符合结果架构的零行或更多行。

此示例通过提供三个元素的元组返回行:

def eval(self, x, y, z):
  yield (x, y, z)

还可以省略括号:

def eval(self, x, y, z):
  yield x, y, z

添加尾随逗号以返回仅包含一列的行:

def eval(self, x, y, z):
  yield x,

还可以生成 pyspark.sql.Row 对象。

def eval(self, x, y, z)
  from pyspark.sql.types import Row
  yield Row(x, y, z)

此示例使用 Python 列表从 terminate 方法生成输出行。 为此,可以将状态存储在 UDTF 评估早期步骤的类中。

def terminate(self):
  yield [self.x, self.y, self.z]

将标量参数传递给 UDTF

可以将标量参数作为由文本值或基于它们的函数组成的常数表达式传递给 UDTF。 例如:

SELECT * FROM udtf(42, group => upper("finance_department"));

将表参数传递给 UDTF

除了标量输入参数外,Python UDF 还可以接受输入表作为参数。 单个 UDTF 还可以接受表参数和多个标量参数。

然后,任何 SQL 查询都可以使用 TABLE 关键字提供输入表,后跟括号中相应的表标识符,例如 TABLE(t)。 或者,可以传递表子查询,例如 TABLE(SELECT a, b, c FROM t)TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key))

然后,输入表参数表示为 eval 方法的 pyspark.sql.Row 参数,对输入表中每一行的 eval 方法进行一次调用。 可以使用标准 PySpark 列字段注释与每行中的列进行交互。 以下示例演示如何显式导入 PySpark Row 类型,然后在 id 字段中筛选传递的表:

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)

若要查询函数,请使用 TABLE SQL 关键字:

SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
|  6|
|  7|
|  8|
|  9|
+---+

从函数调用中指定输入行的分区

使用表参数调用 UDTF 时,任何 SQL 查询都可以根据一个或多个输入表列的值跨多个 UDTF 调用对输入表进行分区。

若要指定分区,请在 TABLE 参数之后的函数调用中使用 PARTITION BY 子句。 这可以保证具有分区列值的每个唯一组合的所有输入行都由 UDTF 类的一个实例使用。

请注意,除了简单的列引用外,PARTITION BY 子句还接受基于输入表列的任意表达式。 例如,可以指定字符串的 LENGTH、从日期中提取月份或连接两个值。

还可以指定 WITH SINGLE PARTITION 而不是 PARTITION BY 以仅请求一个分区,其中所有输入行必须由 UDTF 类的一个实例使用。

在每个分区中,可以选择指定输入行的必需顺序,因为 UDTF 的 eval 方法会使用它们。 为此,请在上述 PARTITION BYWITH SINGLE PARTITION 子句后提供 ORDER BY 子句。

例如,请考虑以下 UDTF:

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="a: string, b: int")
class FilterUDTF:
  def __init__(self):
    self.key = ""
    self.max = 0

  def eval(self, row: Row):
    self.key = row["a"]
    self.max = max(self.max, row["b"])

  def terminate(self):
    yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)

可以通过多种方法在输入表上调用 UDF 时指定分区选项:

-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)";
SELECT * FROM values_table;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 2  |
| "abc" | 4  |
| "def" | 6  |
| "def" | 8  |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 4  |
| "def" | 8  |
+-------+----+

-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
|     a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "def" | 8 |
+-------+----+

通过 analyze 方法指定输入行的分区

请注意,对于在 SQL 查询中调用 UDTF 对输入表进行分区的上述每种方式,可以改为使用 UDTF 的 analyze 方法的相应方式来自动指定相同的分区方法。

  • 可以更新 analyze 方法来设置字段 partitionBy=[PartitioningColumn("a")],然后简单地使用 SELECT * FROM udtf(TABLE(t)) 调用函数,而不是将 UDTF 作为 SELECT * FROM udtf(TABLE(t) PARTITION BY a) 调用。
  • 通过相同的标记,你无需在 SQL 查询中指定 TABLE(t) WITH SINGLE PARTITION ORDER BY b,而可以使 analyze 设置字段 withSinglePartition=trueorderBy=[OrderingColumn("b")],并仅传递 TABLE(t)
  • 无需在 SQL 查询中传递 TABLE(SELECT a FROM t),而是通过 analyze 设置 select=[SelectedColumn("a")],然后只传递 TABLE(t)

在以下示例中,analyze 返回常数输出架构,从输入表中选择列的子集,并指定输入表根据 date 列的值在多个 UDTF 调用中进行分区:

@staticmethod
def analyze(*args) -> AnalyzeResult:
  """
  The input table will be partitioned across several UDTF calls based on the monthly
  values of each `date` column. The rows within each partition will arrive ordered by the `date`
  column. The UDTF will only receive the `date` and `word` columns from the input table.
  """
  from pyspark.sql.functions import (
    AnalyzeResult,
    OrderingColumn,
    PartitioningColumn,
  )

  assert len(args) == 1, "This function accepts one argument only"
  assert args[0].isTable, "Only table arguments are supported"
  return AnalyzeResult(
    schema=StructType()
      .add("month", DateType())
      .add('longest_word", IntegerType()),
    partitionBy=[
      PartitioningColumn("extract(month from date)")],
    orderBy=[
      OrderingColumn("date")],
    select=[
      SelectedColumn("date"),
      SelectedColumn(
        name="length(word),
        alias="length_word")])