1、基本流程

        决策树(decision tree)是一类常见的机器学习方法.以二分类任务为例,我们希望从给定训练数据集学得一个模型用以对新示例进行分类,这个把样本分类的任务,可看作对“当前样本属于正类吗?”这个问题的“决策”或“判定”过程.顾名思义,决策树是基于树结构来进行决策的,这恰是人类在面临决策问题时一种很自然的处理机制.例如,我们要对“这是好瓜吗?”这样的问题进行决策时,通常会进行一系列的判断或“子决策”:我们先看“它是什么颜色?”,如果是“青绿色”,则我们再看“它的根蒂是什么形态?”,如果是“蜷缩”,我们再判断“它敲起来是什么声音?”,最后,我们得出最终决策:这是个好瓜.这个决策过程如图所示.

        一般的,一棵决策树包含一个根结点、若干个内部结点和若干个叶结点;叶结点对应于决策结果,其他每个结点则对应于一个属性测试; 每个结点包含的样本集合根据属性测试的结果被划分到子结点中;根结点包含样本全集.从根结点到每个叶结点的路径对应了一个判定测试序列.决策树学习的目的是为了产生一棵泛化能力强,即处理未见示例能力强的决策树,其基本流程遵循简单且直观的“分而治之”(divide-and-conquer)策略如图所示

        显然,决策树的生成是一个递归过程.在决策树基本算法中,有三种情形会导致递归返回:

(1)当前结点包含的样本全属于同一类别,无需划分;

(2) 当前属性集为空,或是所有样本在所有属性上取值相同,无法划分;

(3) 当前结点包含的样本集合为空,不能划分

        在第(2)种情形下,我们把当前结点标记为叶结点,并将其类别设定为该结点所含样本最多的类别; 在第(3)种情形下,同样把当前结点标记为叶结点,但将其类别设定为其父结点所含样本最多的类别. 注意这两种情形的处理实质不同:情形(2)是在利用当前结点的后验分布,而情形(3)则是把父结点的样本分布作为当前结点的先验分布.

2、划分选择

        由上面的决策树伪代码可看出,决策树学习的关键是第 8 行,即如何选择最优划分属性,一般而言,随着划分过程不断进行,我们希望决策树的分支结点所包含的样本尽可能属于同一类别,即结点的“纯度”(purity)越来越高.

2.1、信息增益

        “信息”(information entropy)是度量样本集合纯度最常用的一种指标. 假定当前样本集合D中第 k类样本所占的比例为p_{k}(k=1,2,..., \left| \gamma \right|), ​则D的信息熵定义为

Ent(D)=-\sum_{k=1}^{ \left| \gamma \right| }p_{k}\log_{2}p_{k}        (1)

计算信息熵时约定: 若 p=0,则 p\log_{2}p=0.

Ent(D)的最小值为0,最大值为 \log_{2}{ \left| \gamma \right| }.

Ent(D)的值越小,则D的纯度越高.

        假定离散属性aV个可能的取值 { \left\{ a^{1},a^{2},...,a^{V} \right\} }, 若使用​a来对样本集 D进行划分,则会产生V个分支结点, 其中第v个分支结点包含了 D中所有在属性 a上取值为a^{v}的样本,记为 D^{v}. 我们可根据(1)计算 D^{v}的信息熵,再考虑到不同的分支结点所包含的样本数不同, 给分支结点赋予权重{ \left| D^{v} \right| }/{ \left| D \right| }, 即样本数越多的分支结点的影响越大,于是可计算出用属性 a对样本集 D进行划分所获得的“信息增益”(information gain)

Gain(D,a)=Ent(D)-\sum_{v=1}^{V} \frac{ \left| D^{v} \right| }{ \left| D \right| } Ent(D^{v})        (2)

        一般而言,信息增益越大,则意味着使用属性 a来进行划分所获得的“纯度提升”越大. 因此, 我们可用信息增益来进行决策树的划分属性选择,即在伪代码算法第8行选择属性 a_{*} = \underset{a\in A}{argmax} \: Gain(D,a). 著名的ID3 决策树学习算法[Quinlan,1986] 就是以信息增益为准则来选择划分属性.(ID3名字中的ID是lterative Dichotomiser(迭代二分器)的简称)

2.2、增益率

        在上面的介绍中,我们如果把数据的索引也作为一个候选划分属性, 如果有17个数据, 则根据式(2)可计算出它的信息增益为 0.998.远大于其他候选划分属性.这很容易理解:“编号”将产生 17 个分支每个分支结点仅包含一个样本,这些分支结点的纯度已达最大.然而,这样的决策树显然不具有泛化能力,无法对新样本进行有效预测.

        实际上,信息增益准则对可取值数目较多的属性有所偏好,为减少这种偏好可能带来的不利影响,著名的 C4.5 决策树算法[Quinlan,1993] 不直接使用信息增益, 而是使用“增益率”(gain ratio)来选择最优划分属性. 采用与式(2)相同的符号表示,增益率定义为

Gain \_ ratio(D,a)=\frac{Gain(D,a)}{IV(a)}        (3)

其中

IV(a)=-\sum_{v=1}^{V} \frac{ \left| D^{v} \right| }{ \left| D \right| } \log_{2} \frac{ \left| D^{v} \right| }{ \left| D \right| }        (4)

称为属性 a的“固有值”(intrinsic value)[Quinlan,1993]. 属性a的可能取值数目越多(即 V越大),则IV(a)的值通常会越大.

        需注意的是,**增益率准则对可取值数目较少的属性有所偏好**,因此,C4.5算法并不是直接选择增益率最大的候选划分属性, 而是使用了一个启发式[Quinlan,1993]: 先从候选划分属性中找出信息增益高于平均水平的属性, 再从中选择增益率最高的.

2.3、基尼指数

        CART是Classificationand Regression Tree的简称,这是一种著名的决策树学习算法, 分类和回归任务都可用.

        CART决策树[Breiman et al,1984 ]使用“基尼指数”(Gini index)来选择划分属性. 采用与式(1) 相同的符号,数据集D的纯度可用基尼值来度量:

\begin{aligned} Gini(D)&=\sum_{k=1}^{\left| \gamma \right|} \sum_{ {k}' \neq k } p_{k} p_{​{k}'} \\ &=1-\sum_{k=1}^{\left| \gamma \right|} p_{k}^{2} \end{aligned}        (5)

        直观来说, Gini(D)反映了从数据集 D中随机抽取两个样本其类别标记不一致的概率因此, Gini(D)越小,则数据集 ​的纯度越高.

        采用与式(2)相同的符号表示,属性 a的基尼指数定义为

Gini\_index(D,a)=\sum_{v=1}^{V} \frac{ \left| D^{v} \right| }{ \left| D \right| } Gini(D^{v})        (6)

        于是,我们在候选属性集合 A中, 选择那个使得划分后基尼指数最小的属性作为最优划分属性, 即 a_{*} = \underset{a\in A}{argmin} \: Gain \_ index(D,a).

3、剪枝处理

        剪枝(pruning)是决策树学习算法对付“过拟合”的主要手段. 在决策树学习中,为了尽可能正确分类训练样本,结点划分过程将不断重复,有时会造成决策树分支过多,这时就可能因训练样本学得“太好”了,以致于把训练集自身的一些特点当作所有数据都具有的一般性质而导致过拟合. 因此, 可通过主动去掉一些分支来降低过拟合的风险.

         决策树剪枝的基本策略有“预剪枝”(prepruning)和“后剪枝”(postpruning)[Quinlan,1993]. 预剪枝是指在决策树生成过程中, 对每个结点在划分前先进行估计, 若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点;后剪枝则是先从训练集生成一棵完整的决策树然后自底向上地对非叶结点进行考察, 若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点.

        如何判断决策树泛化性能是否提升呢?这可使用性能评估方法. 这里假定采用留出法, 即预留一部分数据用作“验证集”以进行性能评估. 例如对包含17条数据的西瓜数据集, 我们将其随机划分为两部分, 如表 4.2 所示, 编号为{1,2,3,6,7,10,14,15,16,17} 的样例组成训练集,编号为{4,5,8,9,11,12,13}的样例组成验证集.

         假定我们采用信息增益准则来进属性选择, 则从表4.2的训练集将会生成一棵如图 4.5 所示的决策树. 为便于讨论, 我们对图中的部分结点做了编号.

3.1、预剪枝

        基于信息增益准则, 我们会选取属性“脐部”来对训练集进行划分,并产生3个分支, 如图 4.6 所示. 然而是否应该进行这个划分呢?预剪枝要对划分前后的泛化性能进行估计.

        在划分之前, 所有样例集中在根结点. 若不进行划分, 则根据决策树的伪代码第6行, 该结点将被标记为叶结点, 其类别标记为训练样例数最多的类别,假设我们将这个叶结点标记为“好瓜”. 用表 4.2的验证集对这个单结点决策树进行评估, 则编号为{4,5,8}的样例被分类正确, 另外 4 个样例分类错误, 于是, 验证集精度为 \frac{3}{7}\times 100\%=42.9\%.

        在用属性“脐部”划分之后,图 4.6 中的结点②③④分别含编号为{1,2,3,14}、{6,7,15,17}、{10,16} 的训练样例, 因此这3 个结点分别被标记为叶结点“好瓜”、“好瓜”、“坏瓜”. 此时, 验证集中编号为{4,5,8,11,12}的样例被分类正确, 验证集精度为\frac{5}{7}\times 100\%=71.4\% > 42.9\%. 于是, 用“脐部”进行划分得以确定.

        然后,决策树算法应该对结点②进行划分, 基于信息增益准则将挑选出划分属性“色泽”. 然而, 在使用“色泽”划分后, 编号为{5}的验证集样本分类结果会由正确转为错误, 使得验证集精度下降为 57.1%. 于是, 预枝策略将禁止结点②被划分.

        对结点③, 最优划分属性为“根蒂”, 划分后验证集精度仍为 71.4%. 这个划分不能提升验证集精度, 于是, 预剪枝策略禁止结点③被划分.

        对结点④, 其所含训练样例已属于同一类,不再进行划分.

        于是,基于预剪枝策略从表 4.2 数据所生成的决策树如图 4.6 所示, 其验证集精度为 71.4%. 这是一棵仅有一层划分的决策树.

        对比图4.6 和图4.5可看出, 预剪枝使得决策树的很多分支都没有“展开”,这不仅降低了过拟合的风险,还显著减少了决策树的训练时间开销和测试时间开销. 但另一方面, 有些分支的当前划分虽不能提升泛化性能、甚至可能导致泛化性能暂时下降, 但在其基础上进行的后续划分却有可能导致性能显著提高; 预剪枝基于“贪心”本质禁止这些分支展开, 给预剪枝决策树带来了欠拟合的风险.

3.2、后剪枝

        后剪枝先从训练集生成一棵完整决策树, 例如基于表4.2 的数据我们得到如图4.5 所示的决策树易知,该决策树的验证集精度为 42.9%.

        后剪枝首先考察图 4.5 中的结点⑥. 若将其领衔的分支剪除, 则相当于把⑥替换为叶结点. 替换后的叶结点包含编号为{7,15}的训练样本, 于是, 该叶结点的类别标记为“好瓜”, 此时决策树的验证集精度提高至 57.1%. 于是后剪枝策略决定剪枝, 如图4.7 所示.

        然后考察结点⑤, 若将其领的子树替换为叶结点, 则替换后的叶结点包含编号为{6,7,15}的训练样例, 叶结点类别标记为“好瓜”, 此时决策树验证集精度仍为 57.1%. 于是, 可以不进行剪枝.

        对结点②, 若将其领衔的子树替换为叶结点, 则替换后的叶结点包含编号为{1,2,3,14}的训练样例, 叶结点标记为“好瓜”. 此时决策树的验证集精度提高至 71.4%. 于是, 后剪枝策略决定剪枝.

        对结点③和①, 若将其领衔的子树替换为叶结点,则所得决策树的验证集精度分别为 71.4%与42.9%, 均未得到提高于是它们被保留.(此种情形下验证集精度虽无提高, 但根据奥卡姆剃刀准则, 剪枝后的模型更好. 因此, 实际的决策树算法在此种情形下通常要进行剪枝. 为绘图的方便,这里采取了不剪枝的保守策略)

        最终, 基于后剪枝策略从表 4.2 数据所生成的决策树如图4.7 所示, 其验证集精度为 71.4%.

        对比图 4.7 和图4.6 可看出, 后枝决策树通常比预枝决策树保留了更多的分支. 一般情形下, 后前枝决策树的欠拟合风险很小, 泛化性能往往优于预剪枝决策树. 但后剪枝过程是在生成完全决策树之后进行的, 并且要自底向上地对树中的所有非叶结点进行逐一考察, 因此其训练时间开销比未剪枝决策树和预剪枝决策树都要大得多.

4、连续与缺失值

4.1、连续值处理

        到目前为止我们仅讨论了基于离散属性来生成决策树. 现实学习任务中常会遇到连续属性, 有必要讨论如何在决策树学习中使用连续属性.

        由于连续属性的可取值数目不再有限,因此, 不能直接根据连续属性的可取值来对结点进行划分. 此时, 连续属性离散化技术可派上用场. 最简单的策略是采用二分法(bi-partiton)对连续属性进行处理, 这正是C4.5决策树算法中采用的机制[Quinlan,1993].

        给定样本集D和连续属性, 假定 aD上出现了n个不同的取值将这些值从小到大进行排序,记为 \{ a^{1},a^{2},...,a^{n} \}. 基于划分点 t可将D分为子集 D_{t}^{-}和​D_{t}^{+}, 其中 D_{t}^{-}包含那些在属性上取值不大于t的样本, 而D_{t}^{+}则包含那些在属性a上取值大于t 的样本, 显然对相邻的属性取值 a^{i}a^{i+1}来说, t在区间 [a^{i},a^{i+1})中取任值所生的划分结果相同. 因此, 对连续属性a, 我们可考察包含 n -1个元素的候选划分点集合

T_{a}= \left\{ \frac{ a^{i}+a^{i+1} }{2} \;| \; 1 \leqslant i \leqslant n-1 \right\}        (7)

        即把区间 [a^{i},a^{i+1})的中位点 \frac{ a^{i}+a^{i+1} }{2}作为候选划分点. 然后, 我们就可像离散属性值一样来考察这些划分点, 选取最优的划分点进行样本集合的划分. 例如可对式(2)稍加改造:

\begin{aligned} Gain(D,a)&=\underset{t \in T_{a}}{max}\; Gain(D,a,t) \\ &= \underset{t \in T_{a}}{max}\; Ent(D) - \sum_{\lambda \in \left\{ -,+ \right\}} \frac{ \left| D^{v} \right| }{ \left| D \right| } Ent(D_{t}^{\lambda}) \end{aligned}        (8)

        其中Gain(D,a,t)是样本集D基于划分点t二分后的信息增益. 于是, 我们就可选择使Gain(D,a,t)最大化的划分点.

        需注意的是, 与离散属性不同, 若当前结点划分属性为连续属性, 该属性还可作为其后代结点的划分属性. 例如在父结点上使用了“密度<0.381”, 不会禁止在子结点上使用“密度<0.294”.

4.2、缺失值处理

        现实任务中常会遇到不完整样本, 即样本的某些属性值缺失. 例如由于诊测成本、隐私保护等因素, 患者的医疗数据在某些属性上的取值(如 HIV 测试结果)未知; 尤其是在属性数目较多的情况下, 往往会有大量样本出现缺失值. 如果简单地放弃不完整样本, 仅使用无缺失值的样本来进行学习, 显然是对数据信息极大的浪费. 例如, 表 4.4 是出现缺失值的版本, 如果放弃不完整样本,则仅有编号{4.7,14,16}的4个样本能被使用, 显然, 有必要考虑利用有缺失属性值的训练样例来进行学习.

         我们需解决两个问题:

(1)如何在属性值缺失的情况下进行划分属性选择?

(2)给定划分属性,若样本在该属性上的值缺失如何对样本进行划分?

         给定训练集D和属性 a, \tilde{D}表示 D中在属性 a上没有缺失值的样本子集. 对问题(1), 显然我们仅可根据 \tilde{D}来判断属性的优劣假定属性aV个可取值 \left\{ a^{1},a^{2},...,a^{V} \right\}, 令 \tilde{D}^{v}表示\tilde{D}中在属性 a上取值为a^{v}的样本子集, \tilde{D}_{k}表示 \tilde{D}中属于第 k(k=1,2,...,|\gamma |)的样本子集, 则显然有\tilde{D}=\bigcup_{k=1}^{|\gamma|} \tilde{D}_{k}\tilde{D}=\bigcup_{v=1}^{V} \tilde{D}^{v}. 假定我们为每个样本 x赋予一个权重 w_{x}(在决策树学习开始阶段根结点中各样本的权重初始化为 1), 并定义

\rho =\frac{ \sum_{x \in \tilde{D} }w_{x} } { \sum_{x \in D }w_{x} }        (9)

\tilde{p}_{k} =\frac{ \sum_{x \in \tilde{D}_{k} }w_{x} } { \sum_{x \in \tilde{D} }w_{x} } \: \: (1\leqslant k \leqslant |\gamma |)        (10)

\tilde{r}_{v} =\frac{ \sum_{x \in \tilde{D}^{v} }w_{x} } { \sum_{x \in \tilde{D} }w_{x} } \: \: (1\leqslant v \leqslant V)        (11)

直观地看,对属性a, \rho表示无缺失值样本所占的比例, \tilde{p}_{k}表示缺失值样本中第 k类所占的比例, \tilde{r}_{v}则表示无缺失值本中在属性 a上取值 a^{v}的样本所占的比例. 显然, \sum_{k=1}^{|\gamma|} \tilde{p}_{k}=1, \sum_{v=1}^{V} \tilde{r}_{v}=1.

        基于上述定义我们可将信息增益的计算式(2)推广为

\begin{aligned} Gain(D,a)&=\rho \times Gain(\tilde{D},a) \\ &=\rho \times \left( Ent(\tilde{D})-\sum_{v=1}^{V} \tilde{r}_{v}Ent(\tilde{D}^{v}) \right) \end{aligned}        (12)

其中由式(1), 有

Ent(\tilde{D})=-\sum_{k=1}^{|\gamma|} \tilde{p}_{k} {\log}_{2}\tilde{p}_{k}

        对问题(2), 若样本 x在划分属性a上的取值已知, 则将 x与其取值对应的子结点, 且样本权值在子结点中保持为w_{x}. 若样本 x在划分属性 a上的取值未知, 则将 x同时划入所有子结点, 且样本权值在与属性值a^{v}对应的子结点中调整为 \tilde{r}_{v} \cdot w_{x}; 直观地看, 这就是让同一个样本以不同的概率划入到不同的子结点中去.

        C4.5算法使用了上述解决方案[Quinlan,1993]. 下面我们以表4.4 的数据集为例来生成一棵决策树.

        上述结点划分过程递归执行,最终生成的决策树如图 4.9 所示

5、代码实现

5.1、使用Scikit-learn库

官网:API Reference — scikit-learn 1.3.0 documentation

以下是使用Python的Scikit-learn库实现决策树的代码示例:

下面我们以考试成绩来进行一个分类例子。某班学生语、数、英三科考试成绩。每科都在80分以上的就算优秀。优秀用1表示,非优秀用0表示。假设我并不知道优秀的标准计算方式,只有一堆学生成绩和评定结果。我如何根据这些数据推算出新学生的成绩评定呢?

from sklearn import tree

# 语数英三科成绩
data_set = [[80, 90, 70], [60, 90, 87], [98, 85, 97], [40, 70, 50], [89, 90, 87], [67, 85, 74], [87, 82, 80],
            [100, 76, 79]]
# 三科成绩对应是澡优秀,1为优秀
labels = [0, 0, 1, 0, 1, 0, 1, 0]

# 生成决策树分类器
clf = tree.DecisionTreeClassifier()
# 分类器训练数据
clf = clf.fit(data_set, labels)

# 分类器预测新成绩数据
print(clf.predict([[80, 45, 76], [90, 97, 84]]))

结果:

[0,1]
# 第一个成绩是0,非优秀,
# 第二个成绩是1,优秀。

5.2、自己手写线性回归

这里使用的是ID3 决策树

没有使用剪枝处理

# -*- coding: UTF-8 -*-
from typing import List, Tuple, Any

import matplotlib.pyplot as plt
from math import log
import operator

"""
这里用的是ID3 决策树的方式
"""

def createDataSet():
	dataSet = [[0, 0, 0, 0, 'no'],						
			[0, 0, 0, 1, 'no'],
			[0, 1, 0, 1, 'yes'],
			[0, 1, 1, 0, 'yes'],
			[0, 0, 0, 0, 'no'],
			[1, 0, 0, 0, 'no'],
			[1, 0, 0, 1, 'no'],
			[1, 1, 1, 1, 'yes'],
			[1, 0, 1, 2, 'yes'],
			[1, 0, 1, 2, 'yes'],
			[2, 0, 1, 2, 'yes'],
			[2, 0, 1, 1, 'yes'],
			[2, 1, 0, 1, 'yes'],
			[2, 1, 0, 2, 'yes'],
			[2, 0, 0, 0, 'no']]
	labels = ['F1-AGE', 'F2-WORK', 'F3-HOME', 'F4-LOAN']		
	return dataSet, labels


def createTree(dataset,labels,featLabels):
	# 获取dataset的最后一行
	classList = [example[-1] for example in dataset]
	# 如果只有一个种类,即全是 'yes' 或者全是 'no'
	if classList.count(classList[0]) == len(classList):
		return classList[0]
	# 如果只有一个特征
	if len(dataset[0]) == 1:
		return majorityCnt(classList)
	# 得到信息增益最大值对应的特征下标
	bestFeat = chooseBestFeatureToSplit(dataset)
	bestFeatLabel = labels[bestFeat]
	featLabels.append(bestFeatLabel)
	myTree = {bestFeatLabel:{}}
	del labels[bestFeat]
	# 获取信息增益最大值对应的特征所在列的所有数据
	featValue = [example[bestFeat] for example in dataset]
	uniqueVals = set(featValue)
	for value in uniqueVals:
		sublabels = labels[:]
		# 字典的树状结构的实现方式
		# {'F3-HOME': {}}
		# {'F2-WORK': {}}
		# {'F2-WORK': {0: 'no'}}
		# {'F2-WORK': {0: 'no', 1: 'yes'}}
		# {'F3-HOME': {0: {'F2-WORK': {0: 'no', 1: 'yes'}}}}
		myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sublabels,featLabels)
	return myTree

def majorityCnt(classList):
	"""
	在只有一个特征的情况下,获取该特征数据量最大的类型
	:param classList:
	:return: 在这个数据里,输出为 'yes' 或者 'no'
	"""
	classCount={}
	for vote in classList:
		if vote not in classCount.keys():classCount[vote] = 0
		classCount[vote] += 1
	# operator.itemgetter(1):获取下标为1的元素
	# key=operator.itemgetter(1):按下标为1的元素进行排序,这边指按value排序
	# reverse=True:降序排序
	sortedclassCount: list[tuple[Any, int]] = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
	return sortedclassCount[0][0]

def chooseBestFeatureToSplit(dataset):
	"""
	选择最好的特征用于作为分类节点
	:param dataset:
	:return:
	"""
	numFeatures = len(dataset[0]) - 1
	# 信息熵
	baseEntropy = calcShannonEnt(dataset)
	bestInfoGain = 0
	bestFeature = -1
	for i in range(numFeatures):
		# 获取特征对应的列数据
		featList = [example[i] for example in dataset]
		# 去重
		uniqueVals = set(featList)
		newEntropy = 0
		for val in uniqueVals:
			subDataSet = splitDataSet(dataset,i,val)
			prob = len(subDataSet)/float(len(dataset))
			newEntropy += prob * calcShannonEnt(subDataSet)
		# 信息增益
		infoGain = baseEntropy - newEntropy
		if (infoGain > bestInfoGain):
			bestInfoGain = infoGain
			bestFeature = i	
	return bestFeature
			

def splitDataSet(dataset,axis,val):
	"""
	将dataset数据按axis列的val类型进行切分
	:param dataset:
	:param axis:
	:param val:
	:return: dataset数据在axis列的值为val的所有行
	"""
	retDataSet = []
	for featVec in dataset:
		if featVec[axis] == val:
			# 将用于切分的axis列去掉,已经使用过了就不参与后面的计算了
			reducedFeatVec = featVec[:axis]
			reducedFeatVec.extend(featVec[axis+1:])
			retDataSet.append(reducedFeatVec)
	return retDataSet
			
def calcShannonEnt(dataset):
	"""
	计算信息熵
	:param dataset:
	:return:
	"""
	numexamples = len(dataset)
	labelCounts = {}
	for featVec in dataset:
		currentlabel = featVec[-1]
		if currentlabel not in labelCounts.keys():
			labelCounts[currentlabel] = 0
		labelCounts[currentlabel] += 1
		
	shannonEnt = 0
	for key in labelCounts:
		prop = float(labelCounts[key])/numexamples
		shannonEnt -= prop*log(prop,2)
	return shannonEnt

# -----------画图需要-------
def getNumLeafs(myTree):
	numLeafs = 0												
	firstStr = next(iter(myTree))								
	secondDict = myTree[firstStr]								
	for key in secondDict.keys():
	    if type(secondDict[key]).__name__=='dict':				
	        numLeafs += getNumLeafs(secondDict[key])
	    else:   numLeafs +=1
	return numLeafs

def getTreeDepth(myTree):
	maxDepth = 0												
	firstStr = next(iter(myTree))								
	secondDict = myTree[firstStr]								
	for key in secondDict.keys():
	    if type(secondDict[key]).__name__=='dict':				
	        thisDepth = 1 + getTreeDepth(secondDict[key])
	    else:   thisDepth = 1
	    if thisDepth > maxDepth: maxDepth = thisDepth			
	return maxDepth
# ------------------------------

# +++++++++++++++++++++ 画图工具,因为已经有成熟的工具类可以直接引用,就不专门讲 +++++++++++++
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
	arrow_args = dict(arrowstyle="<-")											
	createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',	
		xytext=centerPt, textcoords='axes fraction',
		va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
	xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]																
	yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
	createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
	decisionNode = dict(boxstyle="sawtooth", fc="0.8")										
	leafNode = dict(boxstyle="round4", fc="0.8")											
	numLeafs = getNumLeafs(myTree)  														
	depth = getTreeDepth(myTree)															
	firstStr = next(iter(myTree))																								
	cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)	
	plotMidText(cntrPt, parentPt, nodeTxt)													
	plotNode(firstStr, cntrPt, parentPt, decisionNode)										
	secondDict = myTree[firstStr]															
	plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD										
	for key in secondDict.keys():								
		if type(secondDict[key]).__name__=='dict':											
			plotTree(secondDict[key],cntrPt,str(key))        								
		else:																														
			plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
			plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
			plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
	plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


def createPlot(inTree):
	fig = plt.figure(1, facecolor='white')													#创建fig
	fig.clf()																				#清空fig
	axprops = dict(xticks=[], yticks=[])
	createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    							#去掉x、y轴
	plotTree.totalW = float(getNumLeafs(inTree))											#获取决策树叶结点数目
	plotTree.totalD = float(getTreeDepth(inTree))											#获取决策树层数
	plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;								#x偏移
	plotTree(inTree, (0.5,1.0), '')															#绘制决策树
	plt.show()
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++



if __name__ == '__main__':
	dataset, labels = createDataSet()
	featLabels = []
	myTree = createTree(dataset,labels,featLabels)
	createPlot(myTree)
						

6、总结

        决策树基本上是机器学习算法中数学公式推导最少最简单的一种了,需要的是理解公式的含义还有一些剪枝的概念,以及数据的处理方式,可以说是最好全面掌握的算法了。

        最后,觉得有帮助或者有点收获的话,帮忙点个赞吧!

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐