Deep learning model inference workflow

For model inference for deep learning applications, Azure Databricks recommends the following workflow. For details about how to do model inference with TensorFlow and PyTorch, see the model inference examples.

  1. Load the data into Spark DataFrames. Depending on the data type, Azure Databricks recommends the following ways to load the data:

    • Image files (JPG, PNG): Load the image paths into a Spark DataFrame. Image loading and preprocessing of the input data happen in a pandas UDF.
    # file_paths is a Python list of image path strings
    files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])
    
    • TFRecord files: Load the TFRecord files with the Spark `tfrecords` data source.
    df = spark.read.format("tfrecords").load(image_path)
    
    • Data sources such as Parquet, CSV, JSON, JDBC, and other metadata: Load the data using Spark data sources.
  2. Perform model inference using pandas UDFs. pandas UDFs use Apache Arrow to transfer data and pandas to process it. The broad steps for model inference with pandas UDFs are:

    1. Load the trained model: For efficiency, Azure Databricks recommends broadcasting the model's weights from the driver, then loading the model graph and fetching the weights from the broadcast variable inside the pandas UDF.
    2. Load and preprocess input data: To load data in batches, Azure Databricks recommends the tf.data API for TensorFlow and the DataLoader class for PyTorch. Both also support prefetching and multi-threaded loading to hide I/O-bound latency.
    3. Run model prediction: Run model inference on the data batch.
    4. Send predictions back to Spark DataFrames: Collect the prediction results and return them as a pd.Series.
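The four sub-steps above can be sketched as the body of a single pandas UDF. This is a minimal, self-contained sketch: the "model" is a stand-in linear function with hypothetical weights so it runs without TensorFlow or PyTorch, and `bc_weights`, `predict_batch`, and the column names are illustrative names, not part of any Databricks API. In a real job you would broadcast the trained model's weights from the driver and wrap the function with `pyspark.sql.functions.pandas_udf`.

```python
import numpy as np
import pandas as pd

# Step 1 (stand-in): weights "broadcast" from the driver. In Spark this
# would be bc = sc.broadcast(model_weights), read inside the UDF as bc.value.
bc_weights = {"w": np.array([2.0, -1.0]), "b": 0.5}

def predict_batch(features: pd.Series) -> pd.Series:
    """Body of a pandas UDF: one call handles one Arrow batch of rows."""
    # Step 2: assemble the batch into an array. For image data, tf.data or
    # a PyTorch DataLoader would do batched, prefetched loading here.
    x = np.stack(features.to_numpy())            # shape (batch_size, 2)
    # Step 3: run model prediction on the whole batch at once.
    preds = x @ bc_weights["w"] + bc_weights["b"]
    # Step 4: return the predictions as a pd.Series, one value per row.
    return pd.Series(preds)

# In Spark, the function would be registered and applied like this
# (shown for context only):
# from pyspark.sql.functions import pandas_udf
# predict_udf = pandas_udf(predict_batch, returnType="double")
# df = df.withColumn("prediction", predict_udf("features"))
```

Keeping all four steps inside one function matters for performance: the weights are fetched once per executor via the broadcast variable, and each UDF invocation amortizes Arrow transfer and model overhead over an entire batch rather than paying them per row.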