Python user-defined table functions (UDTFs)
Important
This feature is in Public Preview in Databricks Runtime 14.3 LTS and above.
A user-defined table function (UDTF) allows you to register functions that return tables instead of scalar values. Unlike scalar functions that return a single result value from each call, each UDTF is invoked in a SQL statement's FROM clause and returns an entire table as output.
Each UDTF call can accept zero or more arguments. These arguments can be scalar expressions or table arguments representing entire input tables.
Basic UDTF syntax
Apache Spark implements Python UDTFs as Python classes with a mandatory eval method that uses yield to emit output rows.
To use your class as a UDTF, you must import the PySpark udtf function. Databricks recommends using this function as a decorator and explicitly specifying field names and types using the returnType option (unless the class defines an analyze method as described in a later section).
The following UDTF creates a table using a fixed list of two integer arguments:
from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
| 3| -1|
+----+-----+
Register a UDTF
UDTFs are registered to the local SparkSession and are isolated at the notebook or job level.
You cannot register UDTFs as objects in Unity Catalog, and UDTFs cannot be used with SQL warehouses.
You can register a UDTF to the current SparkSession for use in SQL queries with the function spark.udtf.register(). Provide a name for the SQL function and the Python UDTF class.
spark.udtf.register("get_sum_diff", GetSumDiff)
Call a registered UDTF
Once registered, you can use the UDTF in SQL using either the %sql magic command or the spark.sql() function:
spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);")
%sql
SELECT * FROM get_sum_diff(1,2);
Use Apache Arrow
If your UDTF receives a small amount of data as input but outputs a large table, Databricks recommends using Apache Arrow. You can enable it by specifying the useArrow parameter when declaring the UDTF:
@udtf(returnType="c1: int, c2: int", useArrow=True)
Variable argument lists - *args and **kwargs
You can use Python *args or **kwargs syntax and implement logic to handle an unspecified number of input values.
The following example returns the same result while explicitly checking the number and types of the input arguments:
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, *args):
assert(len(args) == 2)
assert(isinstance(arg, int) for arg in args)
x = args[0]
y = args[1]
yield x + y, x - y
GetSumDiff(lit(1), lit(2)).show()
Here is the same example, but using keyword arguments:
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, **kwargs):
x = kwargs["x"]
y = kwargs["y"]
yield x + y, x - y
GetSumDiff(x=lit(1), y=lit(2)).show()
Define a static schema at registration time
The UDTF returns rows with an output schema comprising an ordered sequence of column names and types. If the UDTF schema should always remain the same for all queries, you can specify a static, fixed schema after the @udtf decorator. It must either be a StructType:
StructType().add("c1", StringType())
Or a DDL string representing a struct type:
c1: string
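For example, here is a minimal sketch (the class and column names are illustrative) that declares this single-column schema with a StructType directly in the decorator:

from pyspark.sql.functions import lit, udtf
from pyspark.sql.types import StringType, StructType

# Equivalent to returnType="c1: string"
@udtf(returnType=StructType().add("c1", StringType()))
class EchoUDTF:
    def eval(self, value: str):
        yield value,

EchoUDTF(lit("hello")).show()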
Compute a dynamic schema at function call time
UDTFs can also compute the output schema programmatically for each call depending on the values of the input arguments. To do this, define a static method called analyze that accepts zero or more parameters that correspond to the arguments provided to the specific UDTF call.
Each argument of the analyze method is an instance of the AnalyzeArgument class, which contains the following fields:
| AnalyzeArgument class field | Description |
|---|---|
| dataType | The type of the input argument as a DataType. For input table arguments, this is a StructType representing the table's columns. |
| value | The value of the input argument as an Optional[Any]. This is None for table arguments or literal scalar arguments that are not constant. |
| isTable | Whether the input argument is a table as a BooleanType. |
| isConstantExpression | Whether the input argument is a constant-foldable expression as a BooleanType. |
The analyze method returns an instance of the AnalyzeResult class, which includes the result table's schema as a StructType plus some optional fields. If the UDTF accepts an input table argument, then the AnalyzeResult can also include a requested way to partition and order the rows of the input table across several UDTF calls, as described later.
| AnalyzeResult class field | Description |
|---|---|
| schema | The schema of the result table as a StructType. |
| withSinglePartition | Whether to send all input rows to the same UDTF class instance as a BooleanType. |
| partitionBy | If set to non-empty, all rows with each unique combination of values of the partitioning expressions are consumed by a separate instance of the UDTF class. |
| orderBy | If set to non-empty, this specifies an ordering of rows within each partition. |
| select | If set to non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to evaluate against the columns in the input TABLE argument. The UDTF receives one input attribute for each name in the list in the order they are listed. |
This analyze example returns one output column for each unique word in the input string argument.
from pyspark.sql.functions import AnalyzeArgument, AnalyzeResult, udtf
from pyspark.sql.types import IntegerType, StructType

@udtf
class MyUDTF:
    @staticmethod
    def analyze(text: AnalyzeArgument) -> AnalyzeResult:
        schema = StructType()
        for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
            schema = schema.add(f"word_{index}", IntegerType())
        return AnalyzeResult(schema=schema)

    def eval(self, text: str):
        counts = {}
        for word in text.split(" "):
            if word not in counts:
                counts[word] = 0
            counts[word] += 1
        result = []
        for word in sorted(list(set(text.split(" ")))):
            result.append(counts[word])
        yield result
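For example, registering this UDTF and calling it with a two-word input string produces one column per unique word (the registered name below is an illustrative assumption, not from the original):

spark.udtf.register("word_counts", MyUDTF)
spark.sql("SELECT * FROM word_counts('hello world hello')").columns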
['word_0', 'word_1']
Forward state to future eval calls
The analyze method can serve as a convenient place to perform initialization and then forward the results to future eval method invocations for the same UDTF call.
To do so, create a subclass of AnalyzeResult and return an instance of the subclass from the analyze method. Then, add an additional argument to the __init__ method to accept that instance.
This analyze example returns a constant output schema, but adds custom information in the result metadata to be consumed by future __init__ method calls:
from dataclasses import dataclass
from pyspark.sql.functions import AnalyzeResult, udtf
from pyspark.sql.types import IntegerType, Row, StringType, StructType

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""

@udtf
class TestUDTF:
    def __init__(self, analyze_result=None):
        self._total = 0
        if analyze_result is not None:
            self._buffer = analyze_result.buffer
        else:
            self._buffer = ""

    @staticmethod
    def analyze(argument, _) -> AnalyzeResult:
        if (
            argument.value is None
            or argument.isTable
            or not isinstance(argument.value, str)
            or len(argument.value) == 0
        ):
            raise Exception("The first argument must be a non-empty string")
        assert argument.dataType == StringType()
        assert not argument.isTable
        return AnalyzeResultWithBuffer(
            schema=StructType()
                .add("total", IntegerType())
                .add("buffer", StringType()),
            withSinglePartition=True,
            buffer=argument.value,
        )

    def eval(self, argument, row: Row):
        self._total += 1

    def terminate(self):
        yield self._total, self._buffer

spark.udtf.register("test_udtf", TestUDTF)
spark.sql(
"""
WITH t AS (
SELECT id FROM range(1, 21)
)
SELECT total, buffer
FROM test_udtf("abc", TABLE(t))
"""
).show()
+-------+-------+
| total | buffer|
+-------+-------+
|    20 | "abc" |
+-------+-------+
Yield output rows
The eval method runs once for each row of the input table argument (or just once if no table argument is provided), followed by one invocation of the terminate method at the end. Either method outputs zero or more rows that conform to the result schema by yielding tuples, lists, or pyspark.sql.Row objects.
This example returns a row by providing a tuple of three elements:
def eval(self, x, y, z):
    yield (x, y, z)
You can also omit the parentheses:
def eval(self, x, y, z):
    yield x, y, z
Add a trailing comma to return a row with only one column:
def eval(self, x, y, z):
    yield x,
You can also yield a pyspark.sql.Row object:
from pyspark.sql.types import Row

def eval(self, x, y, z):
    yield Row(x, y, z)
This example yields output rows from the terminate method using a Python list. You can store state inside the class from earlier steps in the UDTF evaluation for this purpose.
def terminate(self):
    yield [self.x, self.y, self.z]
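As a minimal, self-contained sketch of this pattern (the class and column names are illustrative, not from the original), eval can store state on the instance and terminate can emit it as a final row:

from pyspark.sql.functions import lit, udtf

@udtf(returnType="x: int, y: int, z: int")
class LastRowUDTF:
    def __init__(self):
        self.x, self.y, self.z = None, None, None

    def eval(self, x: int, y: int, z: int):
        # Store state only; emit nothing per input row.
        self.x, self.y, self.z = x, y, z

    def terminate(self):
        # Emit one row built from the stored state.
        yield [self.x, self.y, self.z]

LastRowUDTF(lit(1), lit(2), lit(3)).show()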
Pass scalar arguments to a UDTF
You can pass scalar arguments to a UDTF as constant expressions comprising literal values or functions based on them. For example:
SELECT * FROM udtf(42, group => upper("finance_department"));
Pass table arguments to a UDTF
Python UDTFs can accept an input table as an argument in addition to scalar input arguments. A single UDTF can also accept a table argument and multiple scalar arguments.
Then any SQL query can provide an input table using the TABLE keyword followed by parentheses surrounding an appropriate table identifier, like TABLE(t). Alternatively, you can pass a table subquery, like TABLE(SELECT a, b, c FROM t) or TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key)).
The input table argument is then represented as a pyspark.sql.Row argument to the eval method, with one call to the eval method for each row in the input table. You can use standard PySpark column field annotations to interact with columns in each row. The following example demonstrates explicitly importing the PySpark Row type and then filtering the passed table on the id field:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)
To query the function, use the TABLE SQL keyword:
SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
| 6|
| 7|
| 8|
| 9|
+---+
Specify a partitioning of the input rows from function calls
When calling a UDTF with a table argument, any SQL query can partition the input table across several UDTF calls based on the values of one or more input table columns.
To specify a partition, use the PARTITION BY clause in the function call after the TABLE argument. This guarantees that all input rows with each unique combination of values of the partitioning columns will get consumed by exactly one instance of the UDTF class.
Note that in addition to simple column references, the PARTITION BY clause also accepts arbitrary expressions based on input table columns. For example, you can specify the LENGTH of a string, extract a month from a date, or concatenate two values.
It is also possible to specify WITH SINGLE PARTITION instead of PARTITION BY to request only one partition wherein all input rows must be consumed by exactly one instance of the UDTF class.
Within each partition, you can optionally specify a required ordering of the input rows as the UDTF's eval method consumes them. To do so, provide an ORDER BY clause after the PARTITION BY or WITH SINGLE PARTITION clause described above.
For example, consider the following UDTF:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="a: string, b: int")
class FilterUDTF:
    def __init__(self):
        self.key = ""
        self.max = 0

    def eval(self, row: Row):
        self.key = row["a"]
        self.max = max(self.max, row["b"])

    def terminate(self):
        yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)
You can specify partitioning options when calling the UDTF over the input table in multiple ways:
-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8);
SELECT * FROM values_table;
+-------+----+
| a | b |
+-------+----+
| "abc" | 2 |
| "abc" | 4 |
| "def" | 6 |
| "def" | 8 |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
| a | b |
+-------+----+
| "abc" | 4 |
| "def" | 8 |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
| a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
| a | b |
+-------+----+
| "def" | 8 |
+-------+----+
Specify a partitioning of the input rows from the analyze method
Note that for each of the above ways of partitioning the input table when calling UDTFs in SQL queries, there is a corresponding way for the UDTF's analyze method to specify the same partitioning method automatically instead.
- Instead of calling a UDTF as SELECT * FROM udtf(TABLE(t) PARTITION BY a), you can update the analyze method to set the field partitionBy=[PartitioningColumn("a")] and simply call the function using SELECT * FROM udtf(TABLE(t)).
- By the same token, instead of specifying TABLE(t) WITH SINGLE PARTITION ORDER BY b in the SQL query, you can make analyze set the fields withSinglePartition=True and orderBy=[OrderingColumn("b")] and then just pass TABLE(t).
- Instead of passing TABLE(SELECT a FROM t) in the SQL query, you can make analyze set select=[SelectedColumn("a")] and then just pass TABLE(t).
In the following example, analyze returns a constant output schema, selects a subset of columns from the input table, and specifies that the input table is partitioned across several UDTF calls based on the values of the date column:
@staticmethod
def analyze(*args) -> AnalyzeResult:
    """
    The input table will be partitioned across several UDTF calls based on the month
    values of the `date` column. The rows within each partition will arrive ordered by
    the `date` column. The UDTF will only receive the `date` and `word` columns from
    the input table.
    """
    from pyspark.sql.functions import (
        AnalyzeResult,
        OrderingColumn,
        PartitioningColumn,
        SelectedColumn,
    )
    from pyspark.sql.types import DateType, IntegerType, StructType

    assert len(args) == 1, "This function accepts one argument only"
    assert args[0].isTable, "Only table arguments are supported"
    return AnalyzeResult(
        schema=StructType()
            .add("month", DateType())
            .add("longest_word", IntegerType()),
        partitionBy=[
            PartitioningColumn("extract(month from date)")],
        orderBy=[
            OrderingColumn("date")],
        select=[
            SelectedColumn("date"),
            SelectedColumn(
                name="length(word)",
                alias="length_word")])