共用方式為

Unity 目录中的 Python 用户定义的表函数 (UDF)

重要

在 Unity 目录中注册 Python UDF 以 公共预览版提供

Unity Catalog 用户定义的表函数(UDTF)注册的是返回完整表而非标量值的函数。 与从每个调用返回单个结果值的标量函数不同,UDDF 在 SQL 语句的 FROM 子句中调用,并可以返回多个行和列。

UDF 特别适用于:

  • 将数组或复杂数据结构转换为多行
  • 将外部 API 或服务集成到 SQL 工作流中
  • 实现自定义数据生成或扩充逻辑
  • 处理需要跨行执行有状态作的数据

每个 UDTF 调用都接受零个或多个参数。 这些参数可以是表示整个输入表的标量表达式或表参数。

可以通过两种方式注册 UDF:

要求

以下计算类型支持 Unity 目录 Python UDF:

  • 具有标准访问模式的经典计算(Databricks Runtime 17.1 及更高版本)
  • SQL 数据仓库(专业版)

在 Unity 目录中创建 UDTF

使用 SQL DDL 在 Unity 目录中创建受治理的 UDTF。 UDF 是使用 SQL 语句的子句调用的 FROM

CREATE OR REPLACE FUNCTION square_numbers(start INT, end INT)
RETURNS TABLE (num INT, squared INT)
LANGUAGE PYTHON
HANDLER 'SquareNumbers'
DETERMINISTIC
AS $$
class SquareNumbers:
    """
    Basic UDTF that computes a sequence of integers
    and includes the square of each number in the range.
    """
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
$$;

SELECT * FROM square_numbers(1, 5);

+-----+---------+
| num | squared |
+-----+---------+
| 1   | 1       |
| 2   | 4       |
| 3   | 9       |
| 4   | 16      |
| 5   | 25      |
+-----+---------+

Azure Databricks 将 Python UDTF 实现为 Python 类,并使用一个必需的方法eval来生成输出行。

表参数

注释

Databricks Runtime 17.2 及更高版本中支持 TABLE 参数。

UDTF 可以接受整个表作为输入参数,实现复杂的有状态转换和聚合。

eval()terminate() 生命周期方法

UDF 中的表参数使用以下函数来处理每一行:

  • eval():为输入表中的每一行调用一次。 这是主要的处理方法,是必需的。
  • terminate():在每个分区的末尾调用一次,之后所有行都经过 eval()处理。 使用此方法生成最终聚合结果或执行清理作。 此方法是可选的,但对于有状态作(如聚合、计数或批处理)至关重要。

有关eval()terminate()方法的详细信息,请参阅Apache Spark 文档:Python UDTF

行访问模式

eval() 将 TABLE 参数中的行作为 pyspark.sql.Row 对象接收。 可以按列名(row['id']row['name'])或索引(row[0]row[1]来访问值。

  • 架构灵活性:声明不包含架构定义的 TABLE 参数(例如,data TABLEt TABLE)。 该函数接受任何表结构,因此代码应验证所需的列是否存在。

请参阅 示例:将 IP 地址与 CIDR 网络块匹配 ,示例 :使用 Azure Databricks 视觉终结点进行批处理图像字幕

环境隔离

注释

共享隔离环境需要 Databricks Runtime 17.2 及更高版本。 在早期版本中,所有 Unity 目录 Python UDDF 都以严格的隔离模式运行。

默认情况下,具有相同所有者和会话的 Unity 目录 Python UDF 可以共享隔离环境。 这通过减少需要启动的单独环境的数量来提高性能并减少内存使用量。

严格隔离

若要确保 UDTF 始终在其自己的完全隔离环境中运行,请添加 STRICT ISOLATION 特征子句。

大多数 UDF 不需要严格的隔离。 标准数据处理 UDF 受益于默认共享隔离环境,运行速度更快,内存消耗较低。

STRICT ISOLATION 特征子句添加到 UDDF,该子句:

  • 使用eval()exec()或类似函数以代码形式运行输入。
  • 将文件写入本地文件系统。
  • 修改全局变量或系统状态。
  • 访问或修改环境变量。

以下 UDTF 示例设置自定义环境变量,回读变量,并使用变量将一组数字相乘。 由于 UDTF 会改变进程环境,因此请在其中 STRICT ISOLATION运行它。 否则,它可能会泄漏或替代同一环境中其他 UDF/UDF 的环境变量,从而导致行为不正确。

CREATE OR REPLACE TEMPORARY FUNCTION multiply_numbers(factor STRING)
RETURNS TABLE (original INT, scaled INT)
LANGUAGE PYTHON
STRICT ISOLATION
HANDLER 'Multiplier'
AS $$
import os

class Multiplier:
    def eval(self, factor: str):
        # Save the factor as an environment variable
        os.environ["FACTOR"] = factor

        # Read it back and convert it to a number
        scale = int(os.getenv("FACTOR", "1"))

        # Multiply 0 through 4 by the factor
        for i in range(5):
            yield (i, i * scale)
$$;

SELECT * FROM multiply_numbers("3");

如果您的函数产生一致的结果,请设置DETERMINISTIC

如果您的函数在相同输入下生成相同输出,请将 DETERMINISTIC 添加到函数定义中。 这允许查询优化来提高性能。

默认情况下,Batch Unity Catalog Python UDTF 被假定为非确定性的,除非已被显式声明。 非确定性函数的示例包括:生成随机值、访问当前时间或日期或进行外部 API 调用。

请参阅 CREATE FUNCTION (SQL 和 Python)。

实例

以下示例演示了 Unity 目录 Python UDDF 的实际用例,从简单的数据转换到复杂的外部集成。

示例:重新实现 explode

虽然 Spark 提供内置 explode 函数,但创建自己的版本演示了采用单个输入并生成多个输出行的基本 UDTF 模式。

CREATE OR REPLACE FUNCTION my_explode(arr ARRAY<STRING>)
RETURNS TABLE (element STRING)
LANGUAGE PYTHON
HANDLER 'MyExplode'
DETERMINISTIC
AS $$
class MyExplode:
    def eval(self, arr):
        if arr is None:
            return
        for element in arr:
            yield (element,)
$$;

直接在 SQL 查询中使用函数:

SELECT element FROM my_explode(array('apple', 'banana', 'cherry'));
+---------+
| element |
+---------+
| apple   |
| banana  |
| cherry  |
+---------+

或者,将其应用于具有 LATERAL 联接的现有表数据:

SELECT s.*, e.element
FROM my_items AS s,
LATERAL my_explode(s.items) AS e;

示例:通过 REST API 的 IP 地址地理位置

此示例演示 UDDF 如何将外部 API 直接集成到 SQL 工作流中。 分析师可以使用熟悉的 SQL 语法通过实时 API 调用来丰富数据,而无需单独的 ETL 进程。

CREATE OR REPLACE FUNCTION ip_to_location(ip_address STRING)
RETURNS TABLE (city STRING, country STRING)
LANGUAGE PYTHON
HANDLER 'IPToLocationAPI'
AS $$
class IPToLocationAPI:
    def eval(self, ip_address):
        import requests
        api_url = f"https://api.ip-lookup.example.com/{ip_address}"
        try:
            response = requests.get(api_url)
            response.raise_for_status()
            data = response.json()
            yield (data.get('city'), data.get('country'))
        except requests.exceptions.RequestException as e:
            # Return nothing if the API request fails
            return
$$;

注释

使用配置了标准访问模式的计算时,Python UDF 允许通过端口 80、443 和 53 的 TCP/UDP 网络流量。

使用函数通过地理信息丰富 Web 日志数据:

SELECT
  l.timestamp,
  l.request_path,
  geo.city,
  geo.country
FROM web_logs AS l,
LATERAL ip_to_location(l.ip_address) AS geo;

此方法支持实时地理分析,而无需预先处理的查阅表或单独的数据管道。 UDTF 处理 HTTP 请求、JSON 分析和错误处理,使外部数据源可通过标准 SQL 查询访问。

示例:将 IP 地址与 CIDR 网络块匹配

此示例演示了将 IP 地址与 CIDR 网络块匹配,这是一项需要复杂 SQL 逻辑的常见数据工程任务。

首先,使用 IPv4 和 IPv6 地址创建示例数据:

-- An example IP logs with both IPv4 and IPv6 addresses
CREATE OR REPLACE TEMPORARY VIEW ip_logs AS
VALUES
  ('log1', '192.168.1.100'),
  ('log2', '10.0.0.5'),
  ('log3', '172.16.0.10'),
  ('log4', '8.8.8.8'),
  ('log5', '2001:db8::1'),
  ('log6', '2001:db8:85a3::8a2e:370:7334'),
  ('log7', 'fe80::1'),
  ('log8', '::1'),
  ('log9', '2001:db8:1234:5678::1')
t(log_id, ip_address);

接下来,定义并注册 UDTF。 请注意 Python 类结构:

  • t TABLE 参数接受具有任何架构的输入表。 UDTF 会自动适应以处理所提供的任何列。 这种灵活性意味着可以在不同表中使用相同的函数,而无需修改函数签名。 但是,必须仔细检查行的架构以确保兼容性。
  • 该方法 __init__ 用于繁重的一次性设置,例如加载大型网络列表。 此工作在每个输入表的分区中执行一次。
  • 该方法 eval 处理每行并包含核心匹配逻辑。 此方法对输入分区中的每个行执行一次,每个执行由该分区的 IpMatcher UDTF 类的相应实例执行。
  • HANDLER 子句指定实现 UDTF 逻辑的 Python 类的名称。
CREATE OR REPLACE TEMPORARY FUNCTION ip_cidr_matcher(t TABLE)
RETURNS TABLE(log_id STRING, ip_address STRING, network STRING, ip_version INT)
LANGUAGE PYTHON
HANDLER 'IpMatcher'
COMMENT 'Match IP addresses against a list of network CIDR blocks'
AS $$
class IpMatcher:
    def __init__(self):
        import ipaddress
        # Heavy initialization - load networks once per partition
        self.nets = []
        cidrs = ['192.168.0.0/16', '10.0.0.0/8', '172.16.0.0/12',
                 '2001:db8::/32', 'fe80::/10', '::1/128']
        for cidr in cidrs:
            self.nets.append(ipaddress.ip_network(cidr))

    def eval(self, row):
        import ipaddress
	    # Validate that required fields exist
        required_fields = ['log_id', 'ip_address']
        for field in required_fields:
            if field not in row:
                raise ValueError(f"Missing required field: {field}")
        try:
            ip = ipaddress.ip_address(row['ip_address'])
            for net in self.nets:
                if ip in net:
                    yield (row['log_id'], row['ip_address'], str(net), ip.version)
                    return
            yield (row['log_id'], row['ip_address'], None, ip.version)
        except ValueError:
            yield (row['log_id'], row['ip_address'], 'Invalid', None)
$$;

现在已在 ip_cidr_matcher Unity 目录中注册,请使用 TABLE() 语法直接从 SQL 调用它:

-- Process all IP addresses
SELECT
  *
FROM
  ip_cidr_matcher(t => TABLE(ip_logs))
ORDER BY
  log_id;
+--------+-------------------------------+-----------------+-------------+
| log_id | ip_address                    | network         | ip_version  |
+--------+-------------------------------+-----------------+-------------+
| log1   | 192.168.1.100                 | 192.168.0.0/16  | 4           |
| log2   | 10.0.0.5                      | 10.0.0.0/8      | 4           |
| log3   | 172.16.0.10                   | 172.16.0.0/12   | 4           |
| log4   | 8.8.8.8                       | null            | 4           |
| log5   | 2001:db8::1                   | 2001:db8::/32   | 6           |
| log6   | 2001:db8:85a3::8a2e:370:7334  | 2001:db8::/32   | 6           |
| log7   | fe80::1                       | fe80::/10       | 6           |
| log8   | ::1                           | ::1/128         | 6           |
| log9   | 2001:db8:1234:5678::1         | 2001:db8::/32   | 6           |
+--------+-------------------------------+-----------------+-------------+

示例:使用 Azure Databricks 视觉终结点进行批处理图像说明

此示例演示了使用 Azure Databricks 视觉模型服务终结点进行批量图像说明。 它展示了如何使用 terminate() 进行批处理和基于分区的执行。

  1. 创建具有公共映像 URL 的表:

    CREATE OR REPLACE TEMPORARY VIEW sample_images AS
    VALUES
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg', 'scenery'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Camponotus_flavomarginatus_ant.jpg/1024px-Camponotus_flavomarginatus_ant.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Cat_August_2010-4.jpg/1200px-Cat_August_2010-4.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/M101_hires_STScI-PRC2006-10a.jpg/1024px-M101_hires_STScI-PRC2006-10a.jpg', 'scenery')
    images(image_url, category);
    
  2. 创建 Unity Catalog Python UDTF 用于生成图像标题。

    1. 使用配置初始化 UDTF,包括批大小、Azure Databricks API 令牌、视觉模型终结点和工作区 URL。
    2. eval 方法中,将图像 URL 收集到缓冲区中。 当缓冲区达到批大小时,触发批处理。 这可确保在单个 API 调用中同时处理多个映像,而不是每个映像的单个调用。
    3. 在批处理方法中,下载所有缓冲图像,将其编码为 base64,并将其发送到 Databricks VisionModel 的单个 API 请求。 模型同时处理所有图像,并返回整个批处理的标题。
    4. 该方法 terminate 在每个分区的末尾完全执行一次。 在终止方法中,处理缓冲区中的任何剩余图像,并生成所有收集的字幕作为结果。

注释

用您的 Azure Databricks 实际工作区 URL(<workspace-url>)替换 https://your-workspace.cloud.databricks.com

CREATE OR REPLACE TEMPORARY FUNCTION batch_inference_image_caption(data TABLE, api_token STRING)
RETURNS TABLE (caption STRING)
LANGUAGE PYTHON
HANDLER 'BatchInferenceImageCaption'
COMMENT 'batch image captioning by sending groups of image URLs to a Databricks vision endpoint and returning concise captions for each image.'
AS $$
class BatchInferenceImageCaption:
    def __init__(self):
        self.batch_size = 3
        self.vision_endpoint = "databricks-claude-3-7-sonnet"
        self.workspace_url = "<workspace-url>"
        self.image_buffer = []
        self.results = []

    def eval(self, row, api_token):
        self.image_buffer.append((str(row[0]), api_token))
        if len(self.image_buffer) >= self.batch_size:
            self._process_batch()

    def terminate(self):
        if self.image_buffer:
            self._process_batch()
        for caption in self.results:
            yield (caption,)

    def _process_batch(self):
        batch_data = self.image_buffer.copy()
        self.image_buffer.clear()

        import base64
        import httpx
        import requests

        # API request timeout in seconds
        api_timeout = 60
        # Maximum tokens for vision model response
        max_response_tokens = 300
        # Temperature controls randomness (lower = more deterministic)
        model_temperature = 0.3

        # create a batch for the images
        batch_images = []
        api_token = batch_data[0][1] if batch_data else None

        for image_url, _ in batch_data:
            image_response = httpx.get(image_url, timeout=15)
            image_data = base64.standard_b64encode(image_response.content).decode("utf-8")
            batch_images.append(image_data)

        content_items = [{
            "type": "text",
            "text": "Provide brief captions for these images, one per line."
        }]
        for img_data in batch_images:
            content_items.append({
                "type": "image_url",
                "image_url": {
                    "url": "data:image/jpeg;base64," + img_data
                }
            })

        payload = {
            "messages": [{
                "role": "user",
                "content": content_items
            }],
            "max_tokens": max_response_tokens,
            "temperature": model_temperature
        }

        response = requests.post(
            self.workspace_url + "/serving-endpoints/" +
            self.vision_endpoint + "/invocations",
            headers={
                'Authorization': 'Bearer ' + api_token,
                'Content-Type': 'application/json'
            },
            json=payload,
            timeout=api_timeout
        )

        result = response.json()
        batch_response = result['choices'][0]['message']['content'].strip()

        lines = batch_response.split('\n')
        captions = [line.strip() for line in lines if line.strip()]

        while len(captions) < len(batch_data):
            captions.append(batch_response)

        self.results.extend(captions[:len(batch_data)])
$$;

若要使用批处理图像标题 UDTF,请使用示例图像表调用它:

注释

your_secret_scopeapi_token 替换为 Databricks API 令牌的实际机密范围和密钥名称。

SELECT
  caption
FROM
  batch_inference_image_caption(
    data => TABLE(sample_images),
    api_token => secret('your_secret_scope', 'api_token')
  )
+---------------------------------------------------------------------------------------------------------------+
| caption                                                                                                       |
+---------------------------------------------------------------------------------------------------------------+
| Wooden boardwalk cutting through vibrant wetland grasses under blue skies                                     |
| Black ant in detailed macro photography standing on a textured surface                                        |
| Tabby cat lounging comfortably on a white ledge against a white wall                                          |
| Stunning spiral galaxy with bright central core and sweeping blue-white arms against the black void of space. |
+---------------------------------------------------------------------------------------------------------------+

您还可以按照类别生成图像标题:

SELECT
  *
FROM
  batch_inference_image_caption(
    TABLE(sample_images)
    PARTITION BY category ORDER BY (category),
    secret('your_secret_scope', 'api_token')
  )
+------------------------------------------------------------------------------------------------------+
| caption                                                                                              |
+------------------------------------------------------------------------------------------------------+
| Black ant in detailed macro photography standing on a textured surface                               |
| Stunning spiral galaxy with bright center and sweeping blue-tinged arms against the black of space.  |
| Tabby cat lounging comfortably on white ledge against white wall                                     |
| Wooden boardwalk cutting through lush wetland grasses under blue skies                               |
+------------------------------------------------------------------------------------------------------+

示例:ML 模型评估的 ROC 曲线和 AUC 计算

此示例演示了如何使用 scikit-learn 计算接收机工作特性(ROC)曲线和曲线下面积(AUC)得分,以评估二分类模型。

此示例展示了几个重要的模式:

  • 外部库用法:集成 scikit-learn 进行 ROC 曲线计算
  • 有状态聚合:在计算指标之前累积所有行的预测
  • terminate() 方法用法:仅在评估所有行后处理完整的数据集并生成结果
  • 错误处理:验证输入表中是否存在所需的列

UDTF 使用 eval() 方法累积内存中的所有预测,然后在 terminate() 方法中计算并生成完整的 ROC 曲线。 此模式对于需要完整数据集进行计算的指标非常有用。

CREATE OR REPLACE TEMPORARY FUNCTION compute_roc_curve(t TABLE)
RETURNS TABLE (threshold DOUBLE, true_positive_rate DOUBLE, false_positive_rate DOUBLE, auc DOUBLE)
LANGUAGE PYTHON
HANDLER 'ROCCalculator'
COMMENT 'Compute ROC curve and AUC using scikit-learn'
AS $$
class ROCCalculator:
    def __init__(self):
        from sklearn import metrics
        self._roc_curve = metrics.roc_curve
        self._roc_auc_score = metrics.roc_auc_score

        self._true_labels = []
        self._predicted_scores = []

    def eval(self, row):
        if 'y_true' not in row or 'y_score' not in row:
            raise KeyError("Required columns 'y_true' and 'y_score' not found")

        true_label = row['y_true']
        predicted_score = row['y_score']

        label = float(true_label)
        self._true_labels.append(label)
        self._predicted_scores.append(float(predicted_score))

    def terminate(self):
        false_pos_rate, true_pos_rate, thresholds = self._roc_curve(
            self._true_labels,
            self._predicted_scores,
            drop_intermediate=False
        )

        auc_score = float(self._roc_auc_score(self._true_labels, self._predicted_scores))

        for threshold, tpr, fpr in zip(thresholds, true_pos_rate, false_pos_rate):
            yield float(threshold), float(tpr), float(fpr), auc_score
$$;

使用预测创建示例二进制分类数据:

CREATE OR REPLACE TEMPORARY VIEW binary_classification_data AS
SELECT *
FROM VALUES
  ( 1, 1.0, 0.95, 'high_confidence_positive'),
  ( 2, 1.0, 0.87, 'high_confidence_positive'),
  ( 3, 1.0, 0.82, 'medium_confidence_positive'),
  ( 4, 0.0, 0.78, 'false_positive'),
  ( 5, 1.0, 0.71, 'medium_confidence_positive'),
  ( 6, 0.0, 0.65, 'false_positive'),
  ( 7, 0.0, 0.58, 'true_negative'),
  ( 8, 1.0, 0.52, 'low_confidence_positive'),
  ( 9, 0.0, 0.45, 'true_negative'),
  (10, 0.0, 0.38, 'true_negative'),
  (11, 1.0, 0.31, 'low_confidence_positive'),
  (12, 0.0, 0.15, 'true_negative'),
  (13, 0.0, 0.08, 'high_confidence_negative'),
  (14, 0.0, 0.03, 'high_confidence_negative')
AS data(sample_id, y_true, y_score, prediction_type);

计算 ROC 曲线和 AUC:

SELECT
    threshold,
    true_positive_rate,
    false_positive_rate,
    auc
FROM compute_roc_curve(
  TABLE(
    SELECT y_true, y_score
    FROM binary_classification_data
    WHERE y_true IS NOT NULL AND y_score IS NOT NULL
    ORDER BY sample_id
  )
)
ORDER BY threshold DESC;
+-----------+---------------------+----------------------+-------+
| threshold | true_positive_rate  | false_positive_rate  | auc   |
+-----------+---------------------+----------------------+-------+
| 1.95      | 0.0                 | 0.0                  | 0.786 |
| 0.95      | 0.167               | 0.0                  | 0.786 |
| 0.87      | 0.333               | 0.0                  | 0.786 |
| 0.82      | 0.5                 | 0.0                  | 0.786 |
| 0.78      | 0.5                 | 0.125                | 0.786 |
| 0.71      | 0.667               | 0.125                | 0.786 |
| 0.65      | 0.667               | 0.25                 | 0.786 |
| 0.58      | 0.667               | 0.375                | 0.786 |
| 0.52      | 0.833               | 0.375                | 0.786 |
| 0.45      | 0.833               | 0.5                  | 0.786 |
| 0.38      | 0.833               | 0.625                | 0.786 |
| 0.31      | 1.0                 | 0.625                | 0.786 |
| 0.15      | 1.0                 | 0.75                 | 0.786 |
| 0.08      | 1.0                 | 0.875                | 0.786 |
| 0.03      | 1.0                 | 1.0                  | 0.786 |
+-----------+---------------------+----------------------+-------+

局限性

以下限制适用于 Unity 目录 Python UDDF:

后续步骤