使用 TorchDistributor 进行分布式训练

本文介绍如何使用 TorchDistributor 在 PyTorch ML 模型上执行分布式训练。

TorchDistributor 是 PySpark 中的一个开源模块,可帮助用户在其 Spark 群集上使用 PyTorch 进行分布式训练,因此它允许你将 PyTorch 训练作业作为 Spark 作业启动。 在后台,它会初始化环境,并会初始化辅助角色之间的信道,同时利用 CLI 命令 torch.distributed.run 在工作器节点之间运行分布式训练。

TorchDistributor API 支持下表中显示的方法。

方法和签名 说明
init(self, num_processes, local_mode, use_gpu) 创建 TorchDistributor 的实例。
run(self, main, *args) 如果 main 是一个函数,则通过调用 main(**kwargs) 运行分布式训练;如果 main 是一个文件路径,则运行 CLI 命令 torchrun main *args

要求

  • Spark 3.4
  • Databricks Runtime 13.0 ML 或更高版本

笔记本的开发工作流

如果模型创建和训练过程完全通过本地计算机上的笔记本或 Databricks 笔记本进行,则你只需做出少量的更改即可让代码为分布式训练做好准备。

  1. 准备单节点代码:使用 PyTorch、PyTorch Lightning 或其他基于 PyTorch/PyTorch Lightning 的框架(例如 HuggingFace Trainer API)准备和测试单节点代码。

  2. 为标准分布式训练准备代码:需要将单进程训练转换为分布式训练。 将此分布式代码全部包含在一个可与 TorchDistributor 结合使用的训练函数中。

  3. 将 import 语句移入训练函数:在训练函数中添加所需的 import 语句,例如 import torch。 这样可以避免常见的格式转换错误。 此外,模型和数据关联到的 device_id 由以下因素决定:

    device_id = int(os.environ["LOCAL_RANK"])
    
  4. 启动分布式训练:使用所需的参数实例化 TorchDistributor,并调用 .run(*args) 启动训练。

下面是一个训练代码示例:

from pyspark.ml.torch.distributor import TorchDistributor

def train(learning_rate, use_gpu):
  import torch
  import torch.distributed as dist
  import torch.nn.parallel.DistributedDataParallel as DDP
  from torch.utils.data import DistributedSampler, DataLoader

  backend = "nccl" if use_gpu else "gloo"
  dist.init_process_group(backend)
  device = int(os.environ["LOCAL_RANK"]) if use_gpu  else "cpu"
  model = DDP(createModel(), **kwargs)
  sampler = DistributedSampler(dataset)
  loader = DataLoader(dataset, sampler=sampler)

  output = train(model, loader, learning_rate)
  dist.cleanup()
  return output

distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=True)
distributor.run(train, 1e-3, True)

从外部存储库迁移训练

如果你在外部存储库中存储了一个现有的分布式训练过程,可以通过执行以下操作轻松迁移到 Azure Databricks:

  1. 导入存储库:将外部存储库导入为 Databricks Git 文件夹

  2. 创建新笔记本:在存储库中初始化新的 Azure Databricks 笔记本。

  3. 启动分布式训练:在笔记本单元格中,如下所示调用 TorchDistributor

    from pyspark.ml.torch.distributor import TorchDistributor
    
    train_file = "/path/to/train.py"
    args = ["--learning_rate=0.001", "--batch_size=16"]
    distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=True)
    distributor.run(train_file, *args)
    

故障排除

笔记本工作流发生的一个常见错误是在运行分布式训练时找不到对象,或对象的格式已转换。 未向其他执行器分发库 import 语句时,可能会发生此错误。

为避免此问题,请在使用 TorchDistributor(...).run(<func>) 调用的训练函数的顶部,以及在训练方法中调用的其他任何用户定义函数内部包含所有 import 语句(例如 import torch)。

示例笔记本

以下笔记本示例演示如何使用 PyTorch 执行分布式训练。

在 Databricks 笔记本中进行端到端分布式训练

获取笔记本

分布式优化 Hugging Face 模型笔记本

获取笔记本

在 PyTorch 文件笔记本中进行分布式训练

获取笔记本

使用 PyTorch Lightning 笔记本进行分布式训练

获取笔记本

使用 Petastorm 笔记本加载分布式数据

获取笔记本