记录、加载、注册和部署 MLflow 模型

MLflow 模型是一种用于将机器学习模型打包的标准格式,可在多种不同下游工具(例如 Apache Spark 上的批量推理或通过 REST API 提供实时服务)中使用。 该格式定义了一种约定,利用这种约定,能够以不同的模型服务和推理平台可以理解的不同风格(python-function、pytorch、sklearn,等等)保存模型。

若要了解如何记录流模型并为其评分,请参阅如何保存和加载流式处理模型

记录和加载模型

当你记录模型时,MLflow 会自动记录 requirements.txtconda.yaml 文件。 可以使用这些文件重新创建模型开发环境并使用 virtualenv(推荐)或 conda 重新安装依赖项。

重要

Anaconda Inc. 为 anaconda.org 通道更新了其服务条款。 根据新的服务条款,如果依赖 Anaconda 的打包和分发,则可能需要商业许可证。 有关详细信息,请参阅 Anaconda Commercial Edition 常见问题解答。 对任何 Anaconda 通道的使用都受其服务条款的约束。

v1.18(Databricks Runtime 8.3 ML 或更低版本)之前记录的 MLflow 模型默认以 conda defaults 通道 (https://repo.anaconda.com/pkgs/) 作为依赖项进行记录。 由于此许可证更改,Databricks 已停止对使用 MLflow v1.18 及更高版本记录的模型使用 defaults 通道。 记录的默认通道现在为 conda-forge,它指向社区管理的 https://conda-forge.org/

如果在 MLflow v1.18 之前记录了一个模型,但没有从模型的 conda 环境中排除 defaults 通道,则该模型可能依赖于你可能没有预期到的 defaults 通道。 若要手动确认模型是否具有此依赖项,可以检查与记录的模型一起打包的 conda.yaml 文件中的 channel 值。 例如,具有 conda.yaml 通道依赖项的模型 defaults 可能如下所示:

channels:
- defaults
dependencies:
- python=3.8.8
- pip
- pip:
    - mlflow
    - scikit-learn==0.23.2
    - cloudpickle==1.6.0
      name: mlflow-env

由于 Databricks 无法确定是否允许你根据你与 Anaconda 的关系使用 Anaconda 存储库来与模型交互,因此 Databricks 不会强制要求其客户进行任何更改。 如果允许你根据 Anaconda 的条款通过 Databricks 使用 Anaconda.com 存储库,则你不需要采取任何措施。

若要更改模型环境中使用的通道,可以使用新的 conda.yaml 将模型重新注册到模型注册表。 为此,可以在 log_model()conda_env 参数中指定该通道。

有关 log_model() API 的详细信息,请参阅所用模型风格的 MLflow 文档,例如用于 scikit-learn 的 log_model

有关 conda.yaml 文件的详细信息,请参阅 MLflow 文档

API 命令

若要将模型记录到 MLflow 跟踪服务器,请使用 mlflow.<model-type>.log_model(model, ...)

若要加载先前记录的模型用于推理或进一步的开发,请使用 mlflow.<model-type>.load_model(modelpath),其中 modelpath 是以下项之一:

  • 运行相对路径(例如 runs:/{run_id}/{model-path}
  • DBFS 路径
  • 已注册的模型路径(例如 models:/{model_name}/{model_stage})。

有关用于加载 MLflow 模型的完整选项的列表,请参阅MLflow 文档中的“引用项目”

对于 Python MLflow 模型,一个附加选项是使用 mlflow.pyfunc.load_model() 将模型加载为泛型 Python 函数。 可使用以下代码片段来加载模型并为数据点评分。

model = mlflow.pyfunc.load_model(model_path)
model.predict(model_input)

或者,可将模型导出为 Apache Spark UDF 以用于在 Spark 群集上进行评分,或者导出为批处理作业或实时 Spark 流式处理作业。

# load input data table as a Spark DataFrame
input_data = spark.table(input_table_name)
model_udf = mlflow.pyfunc.spark_udf(spark, model_path)
df = input_data.withColumn("prediction", model_udf())

日志模型依赖项

若要准确加载模型,应确保模型依赖项以正确的版本加载到笔记本环境中。 在 Databricks Runtime 10.5 ML 及更高版本中,MLflow 会在检测到当前环境和模型的依赖项之间不匹配时发出警告。

Databricks Runtime 11.0 ML 及更高版本中包含用于简化还原模型依赖项的其他功能。 在 Databricks Runtime 11.0 ML 及更高版本中,对于 pyfunc 风格模型,可以调用 mlflow.pyfunc.get_model_dependencies 来检索和下载模型依赖项。 此函数返回依赖项文件的路径,然后可以使用 %pip install <file-path> 安装该文件。 将模型加载为 PySpark UDF 时,请在 mlflow.pyfunc.spark_udf 调用中指定 env_manager="virtualenv"。 这会在 PySpark UDF 的上下文中还原模型依赖项,并且不会影响外部环境。

在 Databricks Runtime 10.5 或更早版本中,还可通过手动安装 MLflow 版本 1.25.0 或更高版本来使用此功能:

%pip install "mlflow>=1.25.0"

有关如何记录模型依赖性(Python 和非 Python)和项目的更多信息,请参阅记录模型依赖性

了解如何为模型服务记录模型依赖项和自定义项目:

MLflow UI 中自动生成的代码片段

在 Azure Databricks 笔记本中记录模型时,Azure Databricks 会自动生成代码片段,你可以复制这些代码片段并使用它们来加载和运行模型。 若要查看这些代码片段,请执行以下操作:

  1. 导航到生成了模型的运行的“运行”屏幕。 (有关如何显示“运行”屏幕,请参阅查看笔记本试验。)
  2. 滚动到“项目”部分。
  3. 单击已记录的模型的名称。 右侧会打开一个面板,其中显示了可用于加载已记录的模型并在 Spark 或 Pandas 数据帧上进行预测的代码。

项目面板代码片段

示例

有关日志记录模型的示例,请参阅跟踪机器学习训练运行示例中的示例。 有关加载记录的模型进行推理的示例,请参阅模型推理示例

在模型注册表中注册模型

可以在 MLflow 模型注册表中注册模型,该注册表是一个集中的模型存储,它提供了 UI 和一组 API 来管理 MLflow 模型的完整生命周期。 有关如何使用模型注册表管理 Azure Databricks Unity Catalog 中的模型的说明,请参阅在 Unity Catalog 中管理模型生命周期。 若要使用工作区模型注册表,请参阅使用工作区模型注册表(旧版)管理模型生命周期

若要使用 API 注册模型,请使用 mlflow.register_model("runs:/{run_id}/{model-path}", "{registered-model-name}")

将模型保存到 DBFS

若要在本地保存模型,请使用 mlflow.<model-type>.save_model(model, modelpath)modelpath 必须是 modelpath 路径。 例如,如果使用 DBFS 位置 dbfs:/my_project_models 来存储项目工作,则必须使用模型路径 /dbfs/my_project_models

modelpath = "/dbfs/my_project_models/model-%f-%f" % (alpha, l1_ratio)
mlflow.sklearn.save_model(lr, modelpath)

对于 MLlib 模型,请使用 ML 管道

下载模型项目

可以使用各种 API 下载已注册的模型的已记录模型项目(例如模型文件、绘图和指标)。

Python API 示例:

from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository

model_uri = MlflowClient.get_model_version_download_uri(model_name, model_version)
ModelsArtifactRepository(model_uri).download_artifacts(artifact_path="")

Java API 示例:

MlflowClient mlflowClient = new MlflowClient();
// Get the model URI for a registered model version.
String modelURI = mlflowClient.getModelVersionDownloadUri(modelName, modelVersion);

// Or download the model artifacts directly.
File modelFile = mlflowClient.downloadModelVersion(modelName, modelVersion);

CLI 命令示例:

mlflow artifacts download --artifact-uri models:/<name>/<version|stage>

部署用于联机服务的模型

还可以使用 MLflow 的内置部署工具将模型部署到第三方服务框架。 请参阅以下示例。