深度学习模型推理工作流
对于深度学习应用程序的模型推理,Azure Databricks 建议以下使用工作流。 有关使用 TensorFlow 和 PyTorch 的示例笔记本,请参阅深度学习模型推理示例。
将数据加载到 Spark 数据帧。 根据数据类型,Azure Databricks 建议使用以下方法来加载数据:
- 图像文件(JPG、PNG):将图像路径加载到 Spark 数据帧中。 图像加载和预处理输入数据发生在 pandas UDF 中。
files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])
- TFRecord:使用 spark-tensorflow-connector 加载数据。
df = spark.read.format("tfrecords").load(image_path)
- 数据源(如 Parquet、CSV、JSON、JDBC 和其他元数据):使用 Spark 数据源加载数据。
使用 pandas UDF 执行模型推理。pandas UDF 使用 Apache Arrow 传输数据,并使用 pandas 处理数据。 若要进行模型推理,请遵循 pandas UDF 工作流中的主要步骤。
- 加载已训练的模型:为提高效率,Azure Databricks 建议从驱动程序广播模型的权重并加载模型图,然后从 pandas UDF 的广播变量中获得权重。
- 加载和预处理输入数据:若要批量加载数据,Azure Databricks 建议使用 tf.data API(针对 TensorFlow)和 DataLoader 类(针对 PyTorch)。 两者还支持预提取和多线程加载,以隐藏 IO 绑定延迟。
- 运行模型预测:对数据批次运行模型推理。
- 将预测结果发送回 Spark 数据帧:收集预测结果并作为
pd.Series
返回。
深度学习模型推理示例
本节中的示例遵循推荐的深度学习推理工作流。 以下示例演示了如何使用预先训练的深层残差网络 (ResNets) 神经网络模型执行模型推理。