Compartir a través de

使用 MLflow 跟踪基于 Git 的应用程序的版本

本指南演示如何在应用的代码驻留在 Git 或类似的版本控制系统中时跟踪 GenAI 应用程序的版本。 在此工作流中,MLflow LoggedModel 充当 元数据中心,将每个概念应用程序版本链接到其特定的外部代码(例如 Git 提交),配置。 然后,此 LoggedModel 可以与跟踪和评估运行等 MLflow 实体相关联。

mlflow.set_active_model(name=...) 是版本跟踪的关键:调用此函数会将应用程序的跟踪链接到 LoggedModelname如果不存在,则会自动创建一个新LoggedModel项。

本指南介绍以下内容:

  • 使用 LoggedModels 跟踪您的应用程序版本。
  • 将评估运行链接到 LoggedModel

小窍门

Databricks 建议 LoggedModels 与 MLflow 的提示注册表一起使用。 如果使用提示注册表,则每个提示的版本都会自动关联到你的 LoggedModel版本。 请参阅跟踪提示版本和应用程序版本

先决条件

  1. 安装 MLflow 和所需包

    pip install --upgrade "mlflow[databricks]>=3.1.0" openai
    
  2. 请按照设置环境快速入门创建 MLflow 试验。

步骤 1:创建示例应用程序

以下代码创建一个简单的应用程序,用于提示 LLM 进行响应。

  1. 初始化 OpenAI 客户端以连接到由 Databricks 托管的 LLM 或者由 OpenAI 托管的 LLM。

    Databricks 托管的 LLM

    使用 MLflow 获取一个 OpenAI 客户端,以连接到由 Databricks 托管的 LLMs。 从可用的基础模型中选择一个模型。

    import mlflow
    from databricks.sdk import WorkspaceClient
    
    # Enable MLflow's autologging to instrument your application with Tracing
    mlflow.openai.autolog()
    
    # Set up MLflow tracking to Databricks
    mlflow.set_tracking_uri("databricks")
    mlflow.set_experiment("/Shared/docs-demo")
    
    # Create an OpenAI client that is connected to Databricks-hosted LLMs
    w = WorkspaceClient()
    client = w.serving_endpoints.get_open_ai_client()
    
    # Select an LLM
    model_name = "databricks-claude-sonnet-4"
    

    OpenAI 托管的 LLM

    使用本地 OpenAI SDK 连接到由 OpenAI 托管的模型。 从 可用的 OpenAI 模型中选择一个模型。

    import mlflow
    import os
    import openai
    
    # Ensure your OPENAI_API_KEY is set in your environment
    # os.environ["OPENAI_API_KEY"] = "<YOUR_API_KEY>" # Uncomment and set if not globally configured
    
    # Enable auto-tracing for OpenAI
    mlflow.openai.autolog()
    
    # Set up MLflow tracking to Databricks
    mlflow.set_tracking_uri("databricks")
    mlflow.set_experiment("/Shared/docs-demo")
    
    # Create an OpenAI client connected to OpenAI SDKs
    client = openai.OpenAI()
    
    # Select an LLM
    model_name = "gpt-4o-mini"
    
  2. 创建示例应用程序:

    # Use the trace decorator to capture the application's entry point
    @mlflow.trace
    def my_app(input: str):
        # This call is automatically instrumented by `mlflow.openai.autolog()`
        response = client.chat.completions.create(
            model=model_name,  # This example uses a Databricks hosted LLM - you can replace this with any AI Gateway or Model Serving endpoint. If you provide your own OpenAI credentials, replace with a valid OpenAI model e.g., gpt-4o, etc.
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant.",
                },
                {
                    "role": "user",
                    "content": input,
                },
            ],
        )
        return response.choices[0].message.content
    
    result = my_app(input="What is MLflow?")
    print(result)
    

步骤 2:向应用代码添加版本跟踪

LoggedModel 版本充当应用程序特定版本的中央记录(元数据中心)。 它不需要存储应用程序代码本身。 而是指向代码托管的位置(例如 Git 提交哈希)。

使用 mlflow.set_active_model() 来声明你当前正在使用的 LoggedModel 或创建一个新的。 此函数返回一个 ActiveModel 对象,其中包含 model_id,此内容在后续操作中非常有用。

小窍门

在生产环境中,可以设置环境变量 MLFLOW_ACTIVE_MODEL_ID ,而不是调用 set_active_model()。 请参阅 生产指南中的版本跟踪

注释

以下代码使用当前的 Git 提交哈希作为模型的名称,因此模型版本仅在提交时递增。 若要为代码库中的每个更改创建新的 LoggedModel,请参阅 帮助程序函数 ,该函数为代码库中的任何更改创建唯一的 LoggedModel,即使未提交到 Git 也是如此。

在步骤 1 中的应用程序顶部插入以下代码。 在应用程序中,必须调用 set_active_model() 在执行应用代码之前

# Keep original imports
### NEW CODE
import subprocess

# Define your application and its version identifier
app_name = "customer_support_agent"

# Get current git commit hash for versioning
try:
    git_commit = (
        subprocess.check_output(["git", "rev-parse", "HEAD"])
        .decode("ascii")
        .strip()[:8]
    )
    version_identifier = f"git-{git_commit}"
except subprocess.CalledProcessError:
    version_identifier = "local-dev"  # Fallback if not in a git repo
logged_model_name = f"{app_name}-{version_identifier}"

# Set the active model context
active_model_info = mlflow.set_active_model(name=logged_model_name)
print(
    f"Active LoggedModel: '{active_model_info.name}', Model ID: '{active_model_info.model_id}'"
)

### END NEW CODE

### ORIGINAL CODE BELOW
### ...

步骤 3:(可选) 记录参数

可以使用 LoggedModel 将定义此版本应用程序的关键配置参数直接记录到 mlflow.log_model_params()。 这可用于记录绑定到此代码版本的 LLM 名称、温度设置或检索策略等内容。

在步骤 3 中的代码下方添加以下代码:

app_params = {
    "llm": "gpt-4o-mini",
    "temperature": 0.7,
    "retrieval_strategy": "vector_search_v3",
}

# Log params
mlflow.log_model_params(model_id=active_model_info.model_id, params=app_params)

步骤 4:运行应用程序

  1. 调用应用程序以查看 LoggedModel 的创建和跟踪方式。
# These 2 invocations will be linked to the same LoggedModel
result = my_app(input="What is MLflow?")
print(result)

result = my_app(input="What is Databricks?")
print(result)
  1. 若要在不提交的情况下模拟更改,请添加以下行以手动创建新的日志模型。

# Set the active model context
active_model_info = mlflow.set_active_model(name="new-name-set-manually")
print(
    f"Active LoggedModel: '{active_model_info.name}', Model ID: '{active_model_info.model_id}'"
)

app_params = {
    "llm": "gpt-4o",
    "temperature": 0.7,
    "retrieval_strategy": "vector_search_v4",
}

# Log params
mlflow.log_model_params(model_id=active_model_info.model_id, params=app_params)

# This will create a new LoggedModel
result = my_app(input="What is GenAI?")
print(result)

步骤 5:查看与 LoggedModel 相关的跟踪

使用用户界面(UI)

转到 MLflow 试验 UI。 在“ 跟踪 ”选项卡中,可以看到生成每个跟踪的应用版本(请注意,第一个跟踪不会附加版本,因为我们调用了应用而不先调用 set_active_model() 该应用)。 在“ 版本 ”选项卡中,可以看到每个 LoggedModel 参数和链接跟踪。

生成每个踪迹的版本

使用 SDK

可以使用 search_traces() 查询来自 LoggedModel 的跟踪:

import mlflow

traces = mlflow.search_traces(
    filter_string=f"metadata.`mlflow.modelId` = '{active_model_info.model_id}'"
)
print(traces)

您可以使用get_logged_model()获取LoggedModel的详细信息:

import mlflow
import datetime
# Get LoggedModel metadata
logged_model = mlflow.get_logged_model(model_id=active_model_info.model_id)

# Inspect basic properties
print(f"\n=== LoggedModel Information ===")
print(f"Model ID: {logged_model.model_id}")
print(f"Name: {logged_model.name}")
print(f"Experiment ID: {logged_model.experiment_id}")
print(f"Status: {logged_model.status}")
print(f"Model Type: {logged_model.model_type}")
creation_time = datetime.datetime.fromtimestamp(logged_model.creation_timestamp / 1000)
print(f"Created at: {creation_time}")

# Access the parameters
print(f"\n=== Model Parameters ===")
for param_name, param_value in logged_model.params.items():
    print(f"{param_name}: {param_value}")

# Access tags if any were set
if logged_model.tags:
    print(f"\n=== Model Tags ===")
    for tag_key, tag_value in logged_model.tags.items():
        print(f"{tag_key}: {tag_value}")

若要评估应用程序并将结果链接到此 LoggedModel 版本,请参阅 将评估结果和跟踪链接到应用版本。 本指南介绍如何用于 mlflow.genai.evaluate() 评估应用程序的性能,并自动将指标、评估表和跟踪与特定 LoggedModel 版本相关联。

import mlflow
from mlflow.genai import scorers

eval_dataset = [
    {
        "inputs": {"input": "What is the most common aggregate function in SQL?"},
    }
]

mlflow.genai.evaluate(data=eval_dataset, predict_fn=my_app, model_id=active_model_info.model_id, scorers=scorers.get_all_scorers())

在 MLflow 试验 UI 的 “版本评估 ”选项卡中查看结果:

“版本和评估”选项卡

用于计算任何文件更改的唯一哈希的帮助程序函数

下面的帮助程序函数根据存储库的状态自动生成每个 LoggedModel 的名称。 若要使用此函数,请调用 set_active_model(name=get_current_git_hash())

get_current_git_hash() 通过返回 HEAD 提交哈希(用于清理存储库)或 HEAD 哈希和未提交的更改哈希(对于脏存储库)的组合,为 Git 存储库的当前状态生成唯一确定性标识符。 它确保存储库的不同状态始终生成不同的标识符,因此每个代码更改都会导致新的 LoggedModel

import subprocess
import hashlib
import os

def get_current_git_hash():
    """
    Get a deterministic hash representing the current git state.
    For clean repositories, returns the HEAD commit hash.
    For dirty repositories, returns a combination of HEAD + hash of changes.
    """
    try:
        # Get the git repository root
        result = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True, text=True, check=True
        )
        git_root = result.stdout.strip()

        # Get the current HEAD commit hash
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True
        )
        head_hash = result.stdout.strip()

        # Check if repository is dirty
        result = subprocess.run(
            ["git", "status", "--porcelain"], capture_output=True, text=True, check=True
        )

        if not result.stdout.strip():
            # Repository is clean, return HEAD hash
            return head_hash

        # Repository is dirty, create deterministic hash of changes
        # Collect all types of changes
        changes_parts = []

        # 1. Get staged changes
        result = subprocess.run(
            ["git", "diff", "--cached"], capture_output=True, text=True, check=True
        )
        if result.stdout:
            changes_parts.append(("STAGED", result.stdout))

        # 2. Get unstaged changes to tracked files
        result = subprocess.run(
            ["git", "diff"], capture_output=True, text=True, check=True
        )
        if result.stdout:
            changes_parts.append(("UNSTAGED", result.stdout))

        # 3. Get all untracked/modified files from status
        result = subprocess.run(
            ["git", "status", "--porcelain", "-uall"],
            capture_output=True, text=True, check=True
        )

        # Parse status output to handle all file states
        status_lines = result.stdout.strip().split('\n') if result.stdout.strip() else []
        file_contents = []

        for line in status_lines:
            if len(line) >= 3:
                status_code = line[:2]
                filepath = line[3:]  # Don't strip - filepath starts exactly at position 3

                # For any modified or untracked file, include its current content
                if '?' in status_code or 'M' in status_code or 'A' in status_code:
                    try:
                        # Use absolute path relative to git root
                        abs_filepath = os.path.join(git_root, filepath)
                        with open(abs_filepath, 'rb') as f:
                            # Read as binary to avoid encoding issues
                            content = f.read()
                            # Create a hash of the file content
                            file_hash = hashlib.sha256(content).hexdigest()
                            file_contents.append(f"{filepath}:{file_hash}")
                    except (IOError, OSError):
                        file_contents.append(f"{filepath}:unreadable")

        # Sort file contents for deterministic ordering
        file_contents.sort()

        # Combine all changes
        all_changes_parts = []

        # Add diff outputs
        for change_type, content in changes_parts:
            all_changes_parts.append(f"{change_type}:\n{content}")

        # Add file content hashes
        if file_contents:
            all_changes_parts.append("FILES:\n" + "\n".join(file_contents))

        # Create final hash
        all_changes = "\n".join(all_changes_parts)
        content_to_hash = f"{head_hash}\n{all_changes}"
        changes_hash = hashlib.sha256(content_to_hash.encode()).hexdigest()

        # Return HEAD hash + first 8 chars of changes hash
        return f"{head_hash[:32]}-dirty-{changes_hash[:8]}"

    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Git command failed: {e}")
    except FileNotFoundError:
        raise RuntimeError("Git is not installed or not in PATH")

后续步骤