本文介绍如何使用 Mosaic 流式处理将数据从 Apache Spark 转换为与 PyTorch 兼容的格式。
Mosaic Streaming 是一个开源数据流加载库。 它可直接从已加载为 Apache Spark 数据帧的数据集对深度学习模型进行单节点或分布式训练和评估。 Mosaic Streaming 主要支持 Mosaic Composer,并且还与 PyTorch、PyTorch Lightning 和 TorchDistributor 集成。 与传统 PyTorch DataLoaders 相比,Mosaic 流媒体具备多项优势,包括:
与任何数据类型(包括图像、文本、视频和多模式数据)的兼容性。
支持云存储服务提供商(AWS、OCI、Azure、Databricks UC Volume,以及任何与 S3 兼容的对象存储,例如 Cloudflare R2、Coreweave、Backblaze b2 等)
最大限度地保证正确性,以及最大程度地提升性能、灵活性和易用性。 有关详细信息,请查看相应的主要功能页。
有关 Mosaic 流式处理的一般信息,请查看流式处理 API 文档。
注意
Mosaic Streaming 已预安装在 Databricks Runtime 15.2 ML 及更高版本中。
使用 Mosaic Streaming 从 Spark 数据帧加载数据
Mosaic Streaming 提供了一个简单的工作流,可以将 Apache Spark 转换为 Mosaic 数据分片 (MDS) 格式,然后在分布式环境中加载使用。
建议的工作流为:
使用 Apache Spark 来加载数据,还可以选择对数据进行预处理。
使用
streaming.base.converters.dataframe_to_mds
将数据帧保存到磁盘进行暂时存储和/或保存到 Unity Catalog 卷进行持久存储。 此数据将以 MDS 格式存储,并且可以通过对压缩和哈希的支持进行进一步优化。 高级用例还可以包括使用 UDF 对数据进行预处理。 有关详细信息,请查看将 Spark 数据帧转换为 MDS 的教程。使用
streaming.StreamingDataset
将必要的数据加载到内存中。StreamingDataset
是 PyTorch 中 IterableDataset 的一个版本,它支持灵活的确定性随机,可快速从中断的位置恢复。 有关详细信息,请查看 StreamingDataset 文档。使用
streaming.StreamingDataLoader
加载训练/评估/测试所需的数据。StreamingDataLoader
是 PyTorch 中 DataLoader 的一个版本,它提供额外的检查点/恢复接口,可用于跟踪此设置级别中模型已处理的样本数量。
如需了解端到端示例,请查看以下笔记本:
使用 Mosaic Streaming 笔记本简化从 Spark 到 PyTorch 的数据加载
疑难解答:身份验证错误
如果在使用 StreamingDataset
从 Unity Catalog 卷加载数据时出现以下错误,请按如下所示设置环境变量。
ValueError: default auth: cannot configure default credentials, please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication to configure credentials for your preferred authentication method.
注意
如果使用 TorchDistributor
运行分布式训练时看到此错误,还必须在工作器节点上设置环境变量。
db_host = "https://your-databricks-host.databricks.com"
db_token = "YOUR API TOKEN" # Create a token with either method from https://docs.databricks.com/en/dev-tools/auth/index.html#databricks-authentication-methods
def your_training_function():
import os
os.environ['DATABRICKS_HOST'] = db_host
os.environ['DATABRICKS_TOKEN'] = db_token
# The above function can be distributed with TorchDistributor:
# from pyspark.ml.torch.distributor import TorchDistributor
# distributor = TorchDistributor(...)
# distributor.run(your_training_function)