Python 机器学习中,CART(Classification And Regression Trees)算法用于构建决策树,用于分类和回归任务。剪枝(Pruning)是一种避免决策树过拟合的技术,通过减少树的大小来提高模型的泛化能力。CART剪枝分为预剪枝和后剪枝两种主要方式。

参考文档: Python 机器学习 决策树 cart剪枝-CJavaPy

1、预剪枝(Pre-Pruning)

预剪枝涉及在决策树完全生成之前停止树的增长。可以通过设置一些停止条件来实现,

1)树达到预定的最大深度(max_depth)

2)节点中的样本数量少于预定阈值(min_samples_split

3)分割后的节点的信息增益小于某个阈值,

4)节点中样本的纯度(比如,用基尼指数或熵测量)已经足够高。

预剪枝简单易实现,但可能过于保守,有时会导致模型欠拟合。

2、后剪枝(Post-Pruning)

后剪枝,也称为剪枝,是在决策树完全生成之后进行的。它通过删除树的部分子树或节点来减少树的复杂度,选择那些能够提高交叉验证数据集准确率的剪枝。后剪枝策略包括成本复杂度剪枝(Cost Complexity Pruning)、错误率降低剪枝(Reduced Error Pruning)和最小错误剪枝(Minimum Error Pruning)。

成本复杂度剪枝(Cost Complexity Pruning)是通过最小化一个称为成本复杂度的函数来实现剪枝。这个函数是树的错误率和树的复杂度的加权和。

错误率降低剪枝(Reduced Error Pruning)是从叶节点开始,尝试移除每个节点,如果移除后对验证集的分类准确性没有影响或者有所提高,则进行剪枝。

最小错误剪枝(Minimum Error Pruning)是在每个节点上应用一个简单的启发式规则,如果剪枝不会导致错误率增加,则执行剪枝。

3、cart剪枝的作用

决策树通过递归地选择最佳属性将数据集分割,构建出一个树状的分类模型。但一个没有限制的决策树很容易过度拟合训练数据,导致模型在未知数据上的泛化能力下降。为了解决这个问题,决策树剪枝技术被提出来,以提高决策树模型的泛化能力。通过剪掉不必要的节点,减少模型对训练数据中噪声的拟合,从而提高模型在未见数据上的泛化能力。剪枝后的决策树模型更简洁,易于理解和解释,有利于提高模型的可解释性。简化后的模型在预测时计算量更小,预测速度更快。剪枝是提高决策树模型泛化能力和效率的重要技术之一,是决策树算法中不可或缺的一部分。在实际应用中,通过适当选择剪枝策略和参数,可以大幅提升模型的性能。

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载iris数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 训练一个决策树模型(未剪枝)
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
# 预测测试集
y_pred = clf.predict(X_test)
# 评估模型
accuracy_without_pruning = accuracy_score(y_test, y_pred)
# 训练一个决策树模型(使用代价复杂度剪枝)
clf_pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=0.01)  # ccp_alpha是剪枝的复杂度参数
clf_pruned.fit(X_train, y_train)
# 预测测试集
y_pred_pruned = clf_pruned.predict(X_test)
# 评估模型
accuracy_with_pruning = accuracy_score(y_test, y_pred_pruned)
print(accuracy_without_pruning, accuracy_with_pruning)

4、cart剪枝的应用

CART(Classification and Regression Trees)算法用于构建决策树,既可以用于分类问题也可以用于回归问题。一棵完全生长的决策树往往会过于复杂,导致过拟合,即在训练数据上表现很好但在未见过的数据上表现不佳。为了解决这个问题,可以采用剪枝(pruning)技术来简化决策树,提高模型的泛化能力。CART 决策树的构建过程采用贪心算法,不断地划分数据集,直到满足停止条件。 DecisionTreeClassifier 是 scikit-learn 中用于解决分类问题的决策树算法实现。常用参数如下,

使用代码,

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 训练决策树模型
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)
# 成本复杂度剪枝参数
path = tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
# 对每个ccp_alpha训练一个决策树并评估其性能
trees = []
for ccp_alpha in ccp_alphas:
    tree = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    tree.fit(X_train, y_train)
    trees.append(tree)
# 选择最佳的ccp_alpha值(可根据测试集性能来选择)
# 这里简化了选择过程,实际应用中应该使用交叉验证等方法
# 可视化决策树
plt.figure(figsize=(20,10))
plot_tree(trees[-1], filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.draw()
plt.show()

参考文档: Python 机器学习 决策树 cart剪枝-CJavaPy