📜  分类算法-决策树

📅  最后修改于: 2020-12-10 05:36:41             🧑  作者: Mango


决策树简介

通常,决策树分析是一种预测建模工具,可以应用于许多领域。决策树可以通过一种算法方法构建,该算法可以根据不同条件以不同方式拆分数据集。决策树是属于监督算法类别的最强大的算法。

它们可用于分类和回归任务。一棵树的两个主要实体是决策节点,在这里数据被拆分并离开,在这里我们得到结果。下面提供了用于预测一个人是否适合或不适合的二叉树示例,它提供了诸如年龄,饮食习惯和运动习惯等各种信息-

人

在上面的决策树中,问题是决策节点,最终结果是叶子。我们有以下两种类型的决策树-

  • 分类决策树-在这种决策树中,决策变量是分类的。上面的决策树是分类决策树的示例。

  • 回归决策树-在这种决策树中,决策变量是连续的。

实现决策树算法

基尼指数

它是成本函数的名称,用于评估数据集中的二进制拆分,并与分类目标变量“成功”或“失败”一起使用。

基尼指数值越高,同质性越高。理想的基尼系数值为0,最差的值为0.5(对于2类问题)。拆分的基尼系数可以通过以下步骤计算-

  • 首先,使用公式p ^ 2 + q ^ 2计算子节点的Gini指数,这是成功和失败概率的平方之和。

  • 接下来,使用该拆分的每个节点的加权Gini得分计算拆分的Gini指数。

分类和回归树(CART)算法使用Gini方法生成二进制拆分。

分割创作

拆分基本上包括数据集中的一个属性和一个值。我们可以通过以下三个部分在数据集中创建拆分-

  • 第1部分:计算基尼分数-我们在上一节中刚刚讨论了这一部分。

  • 第2部分:拆分数据集-可以定义为将数据集分为两列,每列具有一个属性的索引和该属性的拆分值。从数据集中获得左右两个组之后,我们可以使用第一部分中计算的基尼得分来计算split的值。分割值将决定属性将驻留在哪个组中。

  • 第三部分:评估所有分割-找到基尼得分并分割数据集后的下一部分是所有分割的评估。为此,首先,我们必须检查与每个属性关联的每个值作为候选拆分。然后,我们需要通过评估分割成本来找到最佳分割。最佳拆分将用作决策树中的节点。

建树

我们知道一棵树有根节点和终端节点。创建根节点后,我们可以通过以下两个部分来构建树:

第1部分:终端节点创建

在创建决策树的终端节点时,重要的一点是确定何时停止增长树或创建其他终端节点。可以通过以下两个条件来完成,即最大树深度和最小节点记录-

  • 最大树深度-顾名思义,这是树中根节点之后的最大节点数。一旦一棵树达到最大深度,即一棵树获得最大数量的终端节点,我们就必须停止添加终端节点。

  • 最小节点记录-可以定义为给定节点负责的最小训练模式数。一旦树达到这些最低节点记录或低于此最低节点记录,我们就必须停止添加终端节点。

终端节点用于做出最终预测。

第2部分:递归拆分

正如我们了解何时创建终端节点一样,现在我们可以开始构建树了。递归拆分是一种构建树的方法。在这种方法中,一旦创建了一个节点,我们就可以在每一组数据上递归地创建子节点(添加到现有节点上的节点),这些子节点是通过拆分数据集,一次又一次地调用相同的函数而生成的。

预测

构建决策树后,我们需要对其进行预测。基本上,预测包括使用特定提供的数据行浏览决策树。

如上所述,我们可以借助递归函数进行预测。左侧或右侧子节点再次调用相同的预测例程。

假设条件

以下是我们在创建决策树时所做的一些假设-

  • 在准备决策树时,训练集作为根节点。

  • 决策树分类器更喜欢对要素值进行分类。如果要使用连续值,则必须先离散化它们,然后再建立模型。

  • 根据属性的值,记录将以递归方式分布。

  • 统计方法将用于将属性放置在任何节点位置,即根节点或内部节点。

用Python实现

在以下示例中,我们将在Pima印度糖尿病上实现决策树分类器-

首先,从导入必要的Python包开始-

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

接下来,如下所示从其网络链接下载iris数据集:

col_names = ['pregnant', 'glucose', 'bp', 'skin', 'insulin', 'bmi', 'pedigree', 'age', 'label']
pima = pd.read_csv(r"C:\pima-indians-diabetes.csv", header=None, names=col_names)
pima.head()
pregnant    glucose  bp    skin  insulin  bmi   pedigree    age   label
0       6         148      72    35     0       33.6    0.627     50      1
1       1         85       66    29     0       26.6    0.351     31      0
2       8         183      64     0     0       23.3    0.672     32      1
3       1         89       66    23     94      28.1    0.167     21      0
4       0         137      40    35     168     43.1    2.288     33      1

现在,将数据集分为要素和目标变量,如下所示:

feature_cols = ['pregnant', 'insulin', 'bmi', 'age','glucose','bp','pedigree']
X = pima[feature_cols] # Features
y = pima.label # Target variable

接下来,我们将数据分为训练和测试拆分。以下代码将数据集拆分为70%的训练数据和30%的测试数据-

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

接下来,借助sklearn的DecisionTreeClassifier类训练模型,如下所示-

clf = DecisionTreeClassifier()
clf = clf.fit(X_train,y_train)

最后,我们需要进行预测。可以在以下脚本的帮助下完成-

y_pred = clf.predict(X_test)

接下来,我们可以获得准确性得分,混淆矩阵和分类报告,如下所示:

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
result = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(result)
result1 = classification_report(y_test, y_pred)
print("Classification Report:",)
print (result1)
result2 = accuracy_score(y_test,y_pred)
print("Accuracy:",result2)

输出

Confusion Matrix:
[[116 30]
[ 46 39]]
Classification Report:
            precision   recall   f1-score    support
      0       0.72      0.79       0.75     146
      1       0.57      0.46       0.51     85
micro avg     0.67      0.67       0.67     231
macro avg     0.64      0.63       0.63     231
weighted avg  0.66      0.67       0.66     231

Accuracy: 0.670995670995671   

可视化决策树

上面的决策树可以在以下代码的帮助下可视化-

from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image
import pydotplus

dot_data = StringIO()
export_graphviz(clf, out_file=dot_data,
      filled=True, rounded=True,
      special_characters=True,feature_names = feature_cols,class_names=['0','1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('Pima_diabetes_Tree.png')
Image(graph.create_png())

框