运行 Batch Unity 目录 Python UDF 或 PySpark UDF 时,使用 TaskContext PySpark API 获取上下文信息。
例如,上下文信息(如用户的标识和群集标记)可以验证用户的标识以访问外部服务。
要求
- Databricks Runtime 版本 16.3 及更高版本支持 TaskContext。 
- 以下 UDF 类型支持 TaskContext: 
使用 TaskContext 获取上下文信息
选择一个选项卡以查看 PySpark UDF 或 Batch Unity 目录 Python UDF 的 TaskContext 示例。
PySpark UDF
以下 PySpark UDF 示例打印用户的上下文:
@udf
def log_context():
  import json
  from pyspark.taskcontext import TaskContext
  tc = TaskContext.get()
  # Returns current user executing the UDF
  session_user = tc.getLocalProperty("user")
  # Returns cluster tags
  tags = dict(item.values() for item in json.loads(tc.getLocalProperty("spark.databricks.clusterUsageTags.clusterAllTags  ") or "[]"))
  # Returns current version details
  current_version = {
    "dbr_version": tc.getLocalProperty("spark.databricks.clusterUsageTags.sparkVersion"),
    "dbsql_version": tc.getLocalProperty("spark.databricks.clusterUsageTags.dbsqlVersion")
  }
  return {
    "user": session_user,
    "job_group_id": job_group_id,
    "tags": tags,
    "current_version": current_version
  }
Batch Unity 目录 Python UDF
以下 Batch Unity Catalog Python UDF 示例通过服务凭证获取用户身份,并调用 AWS Lambda 函数:
%sql
CREATE OR REPLACE FUNCTION main.test.call_lambda_func(data STRING, debug BOOLEAN) RETURNS STRING LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
  `batch-udf-service-creds-example-cred` DEFAULT
)
AS $$
import boto3
import json
import pandas as pd
import base64
from pyspark.taskcontext import TaskContext
def batchhandler(it):
  # Automatically picks up DEFAULT credential:
  session = boto3.Session()
  client = session.client("lambda", region_name="us-west-2")
  # Can propagate TaskContext information to lambda context:
  user_ctx = {"custom": {"user": TaskContext.get().getLocalProperty("user")}}
  for vals, is_debug in it:
    payload = json.dumps({"values": vals.to_list(), "is_debug": bool(is_debug[0])})
    res = client.invoke(
      FunctionName="HashValuesFunction",
      InvocationType="RequestResponse",
      ClientContext=base64.b64encode(json.dumps(user_ctx).encode("utf-8")).decode(
        "utf-8"
      ),
      Payload=payload,
    )
    response_payload = json.loads(res["Payload"].read().decode("utf-8"))
    if "errorMessage" in response_payload:
      raise Exception(str(response_payload))
    yield pd.Series(response_payload["values"])
$$;
注册 UDF 后调用它:
SELECT main.test.call_lambda_func(data, false)
FROM VALUES
('abc'),
('def')
AS t(data)
TaskContext 属性
该方法 TaskContext.getLocalProperty() 具有以下属性键:
| 属性键 | 说明 | 示例用法 | 
|---|---|---|
| user | 当前正在执行 UDF 的用户 | tc.getLocalProperty("user")-> "alice" | 
| spark.jobGroup.id | 与当前 UDF 关联的 Spark 作业组 ID | tc.getLocalProperty("spark.jobGroup.id")-> "jobGroup-92318" | 
| spark.databricks.clusterUsageTags.clusterAllTags | 将元数据标签作为键值对,按照 JSON 字典格式的字符串形式进行分组。 | tc.getLocalProperty("spark.databricks.clusterUsageTags.clusterAllTags")-> [{"Department": "Finance"}] | 
| spark.databricks.clusterUsageTags.region | 工作区所在的区域 | tc.getLocalProperty("spark.databricks.clusterUsageTags.region")-> "us-west-2" | 
| accountId | 正在运行的上下文的 Databricks 帐户 ID | tc.getLocalProperty("accountId")-> "1234567890123456" | 
| orgId | 工作区 ID (DBSQL 上不可用) | tc.getLocalProperty("orgId")-> "987654321" | 
| spark.databricks.clusterUsageTags.sparkVersion | 群集的 Databricks Runtime 版本(在非 DBSQL 环境中) | tc.getLocalProperty("spark.databricks.clusterUsageTags.sparkVersion")-> "16.3" | 
| spark.databricks.clusterUsageTags.dbsqlVersion | DBSQL 版本(在 DBSQL 环境中) | tc.getLocalProperty("spark.databricks.clusterUsageTags.dbsqlVersion")-> "2024.35" |