📜  用于可视化的决策树分类器 python 代码 - Python (1)

📅  最后修改于: 2023-12-03 15:40:53.462000             🧑  作者: Mango

用于可视化的决策树分类器 Python 代码

决策树是一种常用的监督学习算法,通常用于分类和回归问题。本文将介绍如何使用 Python 中的 scikit-learn 库实现一个决策树分类器,并将其可视化。

数据集

我们将使用鸢尾花数据集,这是一个经典的分类问题数据集。数据集包含 150 个样本,每个样本都有 4 个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。样本被分为了 3 类:山鸢尾、变色鸢尾和维吉尼亚鸢尾。

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target
建立模型

我们将使用 scikit-learn 中的 DecisionTreeClassifier 类来建立决策树分类器。我们还需要将数据集分成训练和测试集。

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

y_pred_train = clf.predict(X_train)
y_pred_test = clf.predict(X_test)

accuracy_train = accuracy_score(y_train, y_pred_train)
accuracy_test = accuracy_score(y_test, y_pred_test)

print("训练集精度:", accuracy_train)
print("测试集精度:", accuracy_test)
可视化决策树

我们可以使用 Graphviz 来可视化决策树。首先,我们需要安装 Graphviz 并将其加入系统环境变量中。然后,我们可以使用 export_graphviz 函数来生成 Graphviz 文件。最后,我们可以使用 pydotplus 包将 Graphviz 文件转换为图像。

!pip install graphviz
!pip install pydotplus

from sklearn.tree import export_graphviz
import graphviz
import pydotplus

dot_data = export_graphviz(
    clf,
    out_file = None,
    feature_names = iris.feature_names,
    class_names = iris.target_names,
    filled = True,
    rounded = True,
    special_characters = True)

graph = pydotplus.graph_from_dot_data(dot_data)
graphviz.Source(dot_data)
输出结果

我们已经完成了决策树分类器的建立和可视化。训练集精度为 1.0,测试集精度为 0.9667。下面是可视化的决策树图像:

决策树可视化图像

结论

在本文中,我们学习了如何使用 scikit-learn 中的 DecisionTreeClassifier 类建立决策树分类器,并将其可视化。通过可视化决策树,我们可以更好地理解和解释我们的模型。