使用 sparkdl.xgboost 对 XGBoost 模型进行分布式训练

重要

此功能目前以公共预览版提供。

注意

sparkdl.xgboost 从 Databricks Runtime 12.0 ML 开始已被弃用,在 Databricks Runtime 13.0 ML 及更高版本中已被删除。 若要了解如何将工作负荷迁移到 xgboost.spark,请参阅已弃用的 sparkdl.xgboost 模块的迁移指南

Databricks Runtime ML 包含基于 Python xgboostsparkdl.xgboost.XgboostRegressorsparkdl.xgboost.XgboostClassifier 的 PySpark 估算器。 可以基于这些估算器创建 ML 管道。 有关详细信息,请参阅适用于 PySpark 管道的 XGBoost

Databricks 强烈建议 sparkdl.xgboost 用户使用 Databricks Runtime 11.3 LTS ML 或更高版本。 旧版 Databricks Runtime 受旧版 sparkdl.xgboost 中的 bug 的影响。

注意

  • 从 Databricks Runtime 12.0 ML 开始,sparkdl.xgboost 模块已弃用。 Databricks 建议迁移代码以改用 xgboost.spark 模块。 请参阅迁移指南
  • 不支持 xgboost 包中的以下参数:gpu_idoutput_marginvalidate_features
  • 不支持参数 sample_weighteval_setsample_weight_eval_set。 请改用参数 weightColvalidationIndicatorCol。 有关详细信息,请参阅适用于 PySpark 管道的 XGBoost
  • 不支持 base_marginbase_margin_eval_set 参数。 请改用参数 baseMarginCol。 有关详细信息,请参阅适用于 PySpark 管道的 XGBoost
  • 参数 missing 的语义与 xgboost 包不同。 在 xgboost 包中,无论 missing 值是什么,都会将 SciPy 稀疏矩阵中的零值视为缺失值。 对于 sparkdl 包中的 PySpark 估算器,除非设置 missing=0,否则不会将 Spark 稀疏向量中的零值视为缺失值。 如果你有一个稀疏训练数据集(缺失大多数特征值),Databricks 建议设置 missing=0 以减少内存消耗量并实现更好的性能。

分布式训练

Databricks Runtime ML 支持使用 num_workers 参数进行分布式 XGBoost 训练。 要使用分布式训练,请创建一个分类器或回归器并将 num_workers 设置为小于或等于群集上 Spark 任务槽总数的值。 若要使用所有 Spark 任务槽,请设置 num_workers=sc.defaultParallelism

例如:

classifier = XgboostClassifier(num_workers=sc.defaultParallelism)
regressor = XgboostRegressor(num_workers=sc.defaultParallelism)

分布式训练的限制

  • 不能将 mlflow.xgboost.autolog 与分布式 XGBoost 一起使用。
  • 不能将 baseMarginCol 与分布式 XGBoost 一起使用。
  • 不能在启用了自动缩放的群集上使用分布式 XGBoost。 有关禁用自动缩放的说明,请参阅启用自动缩放

GPU 训练

注意

Databricks Runtime 11.3 LTS ML 包括 XGBoost 1.6.1,它不支持计算功能 5.2 及以下的 GPU 群集。

Databricks Runtime 9.1 LTS ML 及更高版本支持使用 GPU 群集进行 XGBoost 训练。 要使用 GPU 群集,请将 use_gpu 设置为 True

例如:

classifier = XgboostClassifier(num_workers=N, use_gpu=True)
regressor = XgboostRegressor(num_workers=N, use_gpu=True)

故障排除

在多节点训练期间,如果遇到 NCCL failure: remote process exited or there was a network error 消息,则通常表示 GPU 之间的网络通信存在问题。 当 NCCL(NVIDIA 集体通信库)无法使用某些网络接口进行 GPU 通信时,就会出现此问题。

若要解决此问题,请将群集的 sparkConf 从 spark.executorEnv.NCCL_SOCKET_IFNAME 设置为 eth。 这实际上是将节点中所有工作器的环境变量 NCCL_SOCKET_IFNAME 设置为 eth

示例笔记本

此笔记本演示如何将 Python 包 sparkdl.xgboost 与 Spark MLlib 配合使用。 自 Databricks Runtime 12.0 ML 起,sparkdl.xgboost 包已弃用。

PySpark-XGBoost 笔记本

获取笔记本