How to extract feature information for tree-based Apache SparkML pipeline models

When you fit a tree-based model, such as a decision tree, random forest, or gradient boosted tree, it is helpful to review the feature importance levels alongside the feature names. Typically, a model in SparkML is fit as the last stage of a pipeline. To extract the relevant feature information from a pipeline that ends in a tree model, you must extract the correct pipeline stages: the fitted tree model holds the importances, and the VectorAssembler holds the feature names. You can extract the feature names from the VectorAssembler object:

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline

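# indexer, assembler, and decision_tree are assumed to be defined earlier
pipeline = Pipeline(stages=[indexer, assembler, decision_tree])
DTmodel = pipeline.fit(train)
va = DTmodel.stages[-2]    # the VectorAssembler stage
tree = DTmodel.stages[-1]  # the fitted decision tree model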

display(tree)  # visualize the decision tree model
print(tree.toDebugString)  # print the nodes of the decision tree model

list(zip(va.getInputCols(), tree.featureImportances))
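
The pairing above can be wrapped in a small helper that also sorts the result. This is a minimal sketch, not part of SparkML: the function name rank_features is hypothetical, and it assumes every input column of the VectorAssembler is a single numeric column. If an input column is itself a vector (for example, one-hot encoded output), the number of names no longer matches the number of importances, so the helper raises an error rather than silently misaligning them.

```python
def rank_features(names, importances):
    """Pair each feature name with its importance, highest importance first.

    `names` is a list of column names (e.g. va.getInputCols()).
    `importances` is either a Spark ML vector (e.g. tree.featureImportances)
    or any plain sequence of numbers.
    """
    # Spark ML vectors expose toArray(); plain sequences are used as-is.
    values = importances.toArray() if hasattr(importances, "toArray") else importances
    values = [float(v) for v in values]

    # A length mismatch means the names and importances cannot be zipped
    # one-to-one (assumption: one name per scalar feature).
    if len(names) != len(values):
        raise ValueError(
            f"{len(names)} feature names but {len(values)} importance values"
        )

    return sorted(zip(names, values), key=lambda pair: pair[1], reverse=True)
```

With the objects from the snippet above, the call would be rank_features(va.getInputCols(), tree.featureImportances).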

You can also tune a tree-based model by using a cross validator as the last stage of the pipeline. To visualize the decision tree and print the feature importance levels, extract the bestModel from the CrossValidator object:

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# paramGrid and evaluator are assumed to be defined earlier
cv = CrossValidator(estimator=decision_tree, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)
pipelineCV = Pipeline(stages=[indexer, assembler, cv])
DTmodelCV = pipelineCV.fit(train)
va = DTmodelCV.stages[-2]                # the VectorAssembler stage
treeCV = DTmodelCV.stages[-1].bestModel  # best tree found by cross validation

display(treeCV)  # visualize the best decision tree model
print(treeCV.toDebugString)  # print the nodes of the decision tree model

list(zip(va.getInputCols(), treeCV.featureImportances))
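
Because the last stage is either the tree model itself or a cross-validation wrapper around it, the extraction step can be written once for both cases. The following is a rough sketch under that assumption: extract_tree_model is a hypothetical helper name, and the SimpleNamespace objects are stand-ins used only so the example runs without a Spark cluster. The only attributes it relies on are stages on the fitted pipeline and bestModel on the cross-validated stage, both of which appear in the snippets above.

```python
from types import SimpleNamespace


def extract_tree_model(fitted_pipeline):
    """Return the tree model from the last stage of a fitted pipeline.

    If the last stage wraps the model (it has a bestModel attribute, as a
    fitted CrossValidator does), return the wrapped model; otherwise return
    the last stage unchanged.
    """
    last_stage = fitted_pipeline.stages[-1]
    return getattr(last_stage, "bestModel", last_stage)


# Stand-in objects for illustration only; these are not Spark classes.
plain_tree = SimpleNamespace(name="tree")
tuned_tree = SimpleNamespace(name="best tree")
plain_pipeline = SimpleNamespace(stages=["indexer", "assembler", plain_tree])
cv_pipeline = SimpleNamespace(
    stages=["indexer", "assembler", SimpleNamespace(bestModel=tuned_tree)]
)

tree_from_plain = extract_tree_model(plain_pipeline)
tree_from_cv = extract_tree_model(cv_pipeline)
```

Applied to the fitted pipelines above, extract_tree_model(DTmodel) and extract_tree_model(DTmodelCV) would return the same objects as tree and treeCV.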

The display function visualizes decision tree models only. See Machine learning visualizations.