使用声明性功能训练模型

重要

此功能在 Beta 版中。 工作区管理员可以从 预览 页控制对此功能的访问。 请参阅 Manage Azure Databricks 预览版

本页介绍如何使用声明性功能进行模型训练。 有关定义声明性功能的信息,请参阅 声明性功能

要求

  • 必须使用声明性功能 API 创建功能。 请参阅 声明性功能

API 方法

create_training_set()

创建声明性功能后,下一步是为模型创建训练数据。 为此,请将标记的数据集传递给create_training_set,这会自动确保每个特征值的计算在时间点上准确无误。

例如:

FeatureEngineeringClient.create_training_set(
    df: DataFrame,                                # DataFrame with training data
    features: Optional[List[Feature]],            # List of Feature objects
    label: Union[str, List[str], None],           # Label column name(s)
    exclude_columns: Optional[List[str]] = None,  # Optional: columns to exclude
) -> TrainingSet

调用 TrainingSet.load_df 将原始训练数据与时间点动态计算特征联接。

df 参数必须满足以下要求:

  • 必须包含功能定义引用的所有实体列。
  • 必须包含特征定义中引用的时序列列。
  • 必须包含在任何 RequestSource 架构中声明的所有列。 根据声明的模式规范对类型进行验证——当出现不匹配时,会引发错误(无隐式强制转换)。
  • 应包含一个或多个标签列。
  • 实体列名、时间系列列名称和请求功能列名称集在所有源中必须全局唯一。

时间点正确性: 对于表源支持的聚合和 ColumnSelection 功能,特征仅使用每行时间戳之前可用的源数据进行计算,以防止将来的数据泄漏到模型训练中。 对于 RequestSource 功能,该值直接从标记的 DataFrame 行获取。

log_model()

使用 MLflow 记录具有特征元数据的模型,以便进行溯源跟踪和推理时自动特征查找。

FeatureEngineeringClient.log_model(
    model,                                    # Trained model object
    artifact_path: str,                       # Path to store model artifact
    flavor: ModuleType,                       # MLflow flavor module (e.g., mlflow.sklearn)
    training_set: TrainingSet,                # TrainingSet used for training
    registered_model_name: Optional[str],     # Optional: register model in Unity Catalog
)

flavor 参数指定要使用的 MLflow 模型风格 模块,例如 mlflow.sklearnmlflow.xgboost

使用TrainingSet记录的模型会自动跟踪训练中使用的特征谱系。 训练集包含 RequestSource 特征时,这些 RequestSource 列将作为所需的输入添加到 MLflow 模型签名中。 这可确保服务终结点的 API 架构反映调用方在推理时必须提供的字段。 有关详细信息,请参阅 使用特征表训练模型

score_batch()

使用自动特征查找执行批量推理

FeatureEngineeringClient.score_batch(
    model_uri: str,                           # URI of logged model
    df: DataFrame,                            # DataFrame with entity keys and timestamps
) -> DataFrame

score_batch 使用与模型一起存储的特征元数据,自动计算时间点上正确的特征以进行推理,从而确保与训练的一致性。 有关详细信息,请参阅 使用特征表训练模型

示例工作流

import mlflow
from databricks.feature_engineering import FeatureEngineeringClient
from sklearn.ensemble import RandomForestClassifier

fe = FeatureEngineeringClient()

# Assume features are registered in UC
# labeled_df should have columns "user_id", "transaction_time", and "is_fraud"

# 1. Create training set using declarative features
training_set = fe.create_training_set(
    df=labeled_df,
    features=features,
    label="is_fraud",
)

# 2. Load training data with computed features
training_df = training_set.load_df()
X = training_df.drop("is_fraud").toPandas()
y = training_df.select("is_fraud").toPandas().values.ravel()

# 3. Train model
model = RandomForestClassifier().fit(X, y)

# 4. Log model with feature metadata
with mlflow.start_run():
    fe.log_model(
        model=model,
        artifact_path="fraud_model",
        flavor=mlflow.sklearn,
        training_set=training_set,
        registered_model_name="main.ecommerce.fraud_model",
    )

# 5. Batch scoring with automatic feature lookup
# inference_df must contain the same entity and timeseries columns
# used during training. Features are automatically computed.
predictions = fe.score_batch(
    model_uri="models:/main.ecommerce.fraud_model/1",
    df=inference_df,
)
predictions.display()

使用 RequestSource 功能进行训练

如果模型需要在推理时提供的数据(例如 API 调用中的事务详细信息),请使用RequestSource特性以及表支持的特性。 在训练期间,将从标记的数据帧中提取 RequestSource 列。

from databricks.feature_engineering import FeatureEngineeringClient
from databricks.feature_engineering.entities import (
    DeltaTableSource, Feature, FieldDefinition, RequestSource,
    ScalarDataType, ColumnSelection,
)

fe = FeatureEngineeringClient()

# RequestSource provides transaction data at inference time
request_source = RequestSource(
    schema=[
        FieldDefinition(name="transaction_amount", data_type=ScalarDataType.DOUBLE),
        FieldDefinition(name="vendor_id", data_type=ScalarDataType.STRING),
        FieldDefinition(name="transaction_id", data_type=ScalarDataType.STRING),
        FieldDefinition(name="transaction_time", data_type=ScalarDataType.DATE),
    ]
)

delta_source = DeltaTableSource(
    catalog_name="catalog",
    schema_name="schema",
    table_name="vendor_data",
)

# A column selection feature from the request source (pass-through)
latest_transaction_amount = Feature(
    source=request_source,
    function=ColumnSelection("transaction_amount"),
    name="latest_transaction_amount",
)

# A lookup feature from a delta table
vendor_category = Feature(
    source=delta_source,
    function=ColumnSelection("vendor_category"),
    entity=["vendor_id"],
    timeseries_column="transaction_time",
    name="vendor_category",
)

# labels_df must contain: transaction_id, transaction_time, vendor_id,
# transaction_amount, and the label column.
ts = fe.create_training_set(
    df=labels_df,
    features=[latest_transaction_amount, vendor_category],
    label="is_fraud",
    exclude_columns=["card_id"],
)

import mlflow
from sklearn.ensemble import RandomForestClassifier

with mlflow.start_run():
    training_df = ts.load_df().toPandas()
    X = training_df.drop(columns=["is_fraud"])
    y = training_df["is_fraud"]
    model = RandomForestClassifier().fit(X, y)

    # log_model() adds RequestSource columns to the MLflow model signature
    fe.log_model(
        model=model,
        artifact_path="fraud_model",
        flavor=mlflow.sklearn,
        training_set=ts,
        registered_model_name="catalog.schema.fraud_model",
    )

在服务时输入到原始模型的内容

功能存储的模型包装器会在将列传递给原始模型之前,先过滤这些列。

列类型 到达内部模型?
显式特征输出(ColumnSelection,聚合)
RequestSource 声明为特征的列
实体列(查找键) 否(除非显式声明为特征)
时间序列列 否(除非显式声明为特征)