使用 sparkdl.xgboost
对 XGBoost 模型进行分布式训练
重要
此功能目前以公共预览版提供。
注意
sparkdl.xgboost
从 Databricks Runtime 12.0 ML 开始已被弃用,在 Databricks Runtime 13.0 ML 及更高版本中已被删除。 若要了解如何将工作负荷迁移到 xgboost.spark
,请参阅已弃用的 sparkdl.xgboost 模块的迁移指南。
Databricks Runtime ML 包含基于 Python xgboost
包 sparkdl.xgboost.XgboostRegressor
和 sparkdl.xgboost.XgboostClassifier
的 PySpark 估算器。 可以基于这些估算器创建 ML 管道。 有关详细信息,请参阅适用于 PySpark 管道的 XGBoost。
Databricks 强烈建议 sparkdl.xgboost
用户使用 Databricks Runtime 11.3 LTS ML 或更高版本。 旧版 Databricks Runtime 受旧版 sparkdl.xgboost
中的 bug 的影响。
注意
- 从 Databricks Runtime 12.0 ML 开始,
sparkdl.xgboost
模块已弃用。 Databricks 建议迁移代码以改用xgboost.spark
模块。 请参阅迁移指南。 - 不支持
xgboost
包中的以下参数:gpu_id
、output_margin
、validate_features
。 - 不支持参数
sample_weight
、eval_set
和sample_weight_eval_set
。 请改用参数weightCol
和validationIndicatorCol
。 有关详细信息,请参阅适用于 PySpark 管道的 XGBoost。 - 不支持
base_margin
和base_margin_eval_set
参数。 请改用参数baseMarginCol
。 有关详细信息,请参阅适用于 PySpark 管道的 XGBoost。 - 参数
missing
的语义与xgboost
包不同。 在xgboost
包中,无论missing
值是什么,都会将 SciPy 稀疏矩阵中的零值视为缺失值。 对于sparkdl
包中的 PySpark 估算器,除非设置missing=0
,否则不会将 Spark 稀疏向量中的零值视为缺失值。 如果你有一个稀疏训练数据集(缺失大多数特征值),Databricks 建议设置missing=0
以减少内存消耗量并实现更好的性能。
分布式训练
Databricks Runtime ML 支持使用 num_workers
参数进行分布式 XGBoost 训练。 要使用分布式训练,请创建一个分类器或回归器并将 num_workers
设置为小于或等于群集上 Spark 任务槽总数的值。 若要使用所有 Spark 任务槽,请设置 num_workers=sc.defaultParallelism
。
例如:
classifier = XgboostClassifier(num_workers=sc.defaultParallelism)
regressor = XgboostRegressor(num_workers=sc.defaultParallelism)
分布式训练的限制
- 不能将
mlflow.xgboost.autolog
与分布式 XGBoost 一起使用。 - 不能将
baseMarginCol
与分布式 XGBoost 一起使用。 - 不能在启用了自动缩放的群集上使用分布式 XGBoost。 有关禁用自动缩放的说明,请参阅启用自动缩放。
GPU 训练
注意
Databricks Runtime 11.3 LTS ML 包括 XGBoost 1.6.1,它不支持计算功能 5.2 及以下的 GPU 群集。
Databricks Runtime 9.1 LTS ML 及更高版本支持使用 GPU 群集进行 XGBoost 训练。 要使用 GPU 群集,请将 use_gpu
设置为 True
。
例如:
classifier = XgboostClassifier(num_workers=N, use_gpu=True)
regressor = XgboostRegressor(num_workers=N, use_gpu=True)
故障排除
在多节点训练期间,如果遇到 NCCL failure: remote process exited or there was a network error
消息,则通常表示 GPU 之间的网络通信存在问题。 当 NCCL(NVIDIA 集体通信库)无法使用某些网络接口进行 GPU 通信时,就会出现此问题。
若要解决此问题,请将群集的 sparkConf 从 spark.executorEnv.NCCL_SOCKET_IFNAME
设置为 eth
。 这实际上是将节点中所有工作器的环境变量 NCCL_SOCKET_IFNAME
设置为 eth
。
示例笔记本
此笔记本演示如何将 Python 包 sparkdl.xgboost
与 Spark MLlib 配合使用。 自 Databricks Runtime 12.0 ML 起,sparkdl.xgboost
包已弃用。