机器学习-决策树的编写(1):ID3简单实现决策树并可视化

前言

本文为作者的第一次尝试编写,没有用到机器学习库,代码都是手撸的,数据都是手动输入,也没有剪枝……但实现效果还不错,便发出来做一次记录,可供初学者参考。

决策树的基本概念

决策树,简单而言,就是从树的根节点开始,在内部节点处需要做判断,直到到达一个叶子节点处,得到决策结果。

决策树的判别标准:
1.如果节点对应数据子集中的样本基本属于同一个类别,则无需对节点的数据子集做进一步划分,否则就要对该节点的数据子集做进一步划分,生成新的判别标准;
2.如果新判别标准能够基本上把结点上不同类别的数据分离开,使得每个子结点都是类别比较单一的数据,那么该判别标准就是一个好规则,否则需重新选取判别标准。

决策树的构造一般分为三个部分:特征选择(本文使用ID3,即使用“信息增益准则”),决策树生成,剪枝(本文未使用)

信息熵:度量样本集合“纯度”最常用的一种指标。假定当前样本集合 D 中第 k 类样本所占的比例为p_i ,则D的信息熵定义为:
在这里插入图片描述

其值越大,则表明D中所包含样本标签取值越杂乱。

信息增益:以信息熵为基础,计算当前划分对信息熵所造成的变化。以属性a对数据集D进行划分所获得的信息增益为:
在这里插入图片描述
ID3选择信息增益最大的为“划分属性”。

所用数据集

由于是第一次“实验性质”的编写,数据集直接用了书上的“西瓜数据集”:
在这里插入图片描述

python文件结构

在这里插入图片描述

代码

init.py
from Tree.Comentropy import *
from Tree.CreateTree import *
from Tree.Draw import *
def main():
    dataLabel=['色泽','根蒂','敲声','纹理','脐部','触感','好瓜']#数据标签
    dataSet=[ ['青绿','蜷缩','浊响','清晰','凹陷','硬滑','是'],
              ['乌黑','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
              ['乌黑','蜷缩','浊响','清晰','凹陷','硬滑','是'],
              ['青绿','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
              ['浅白','蜷缩','浊响','清晰','凹陷','硬滑','是'],
              ['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
              ['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
              ['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
              ['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否'],
              ['青绿','硬挺','清脆','清晰','平坦','软粘','否'],
              ['浅白','硬挺','清脆','模糊','平坦','硬滑','否'],
              ['浅白','蜷缩','浊响','模糊','平坦','软粘','否'],
              ['青绿','稍蜷','浊响','稍糊','凹陷','硬滑','否'],
              ['浅白','稍蜷','沉闷','稍糊','凹陷','硬滑','否'],
              ['乌黑','稍蜷','浊响','清晰','稍凹','软粘','否'],
              ['浅白','蜷缩','浊响','模糊','平坦','硬滑','否'],
              ['青绿','蜷缩','沉闷','稍糊','稍凹','硬滑','否'] ]#数据集
    print("信息熵为{}".format(calComentropy(dataSet,6)))
    print("色泽的信息增益为{}".format(calCondComentropy(dataSet,0,6)))
    root=Node('根',6)
    root.index=dataSet
    root.label=dataLabel
    #print(root.index)
    tree=BinaryTree(root)
    tree.createWholeTree(root,6)
    tree.lengthTree(root,1)
    print("树的高度是",tree.high)
    tree.leafTree(root)
    print("树的叶子结点个数", tree.leafNum, end='')
    #createPlot()
    createPlot.ax1 = plt.subplot(111, frameon=False)  # frameon表示是否绘制坐标轴矩形
    tree.drawTree(root, (0,0), (0,0))
    plt.axis('off')# 去掉坐标轴以及刻度
    plt.show()
if __name__ == '__main__':
    main()
Comentropy.py
from math import log
def calComentropy(dataSet, place):#计算信息熵
    lenDataSet=len(dataSet)
    dict={}#决策元素出现了几次
    ans=0.0
    for data in dataSet:
        needJudge=data[place]#取决策元素
        if needJudge not in dict:
            dict[needJudge]=0#第一次出现,将其加入字典
        dict[needJudge]+=1#出现次数
    for key in dict:
        p = float(dict[key]) / float(lenDataSet)
        ans-=p*log(p,2)
    return ans
def calCondComentropy(dataSet,placeCon,placeRoot):#计算信息增量
    lenDataSet=len(dataSet)
    ent=[]#第i个元素:三个值,分别为条件名,出现次数,熵值
    sizeCond=0#一共有几种该属性条件
    dict={}#条件元素出现了几次
    for data in dataSet:
        cond=data[placeCon]
        if cond not in dict:
            dict[cond]=0#第一次出现该条件,将其加入字典
            sizeCond+=1
        dict[cond]+=1
    dictRep = {}#用于下面的循环,判断某条件是否出现过
    for data in dataSet:
        cond=data[placeCon]
        if cond not in dictRep:
            dataSetCond=[]# 将符合该条件的新建数据集存储
            for dataCond in dataSet:
                if dataCond[placeCon]==cond:
                    dataSetCond.append(dataCond)
            #print(cond,' ',dict[cond],' ',calComentropy(dataSetCond,placeRoot))
            ent.append([cond,dict[cond],calComentropy(dataSetCond,placeRoot)])
            dictRep[cond] = 0#标记
    cmentropyNum=0.0#计算信息增量的减数
    ans=0.0
    for entCond in ent:
        cmentropyNum+=entCond[2]*float(entCond[1])/float(lenDataSet)
    ans=calComentropy(dataSet,placeRoot)-cmentropyNum
    return ans
def judgeSame(dataSet,placeRoot):#判断结果是否相同
    dict = {}#几种结果
    for data in dataSet:
        if data[placeRoot] not in dict:
            dict[data[placeRoot]]=1#加入字典
    if len(dict)==1:
        return 1
    else:
        return len(dict)
def cmp(dataSet,sum,placeRoot):#比较信息增量
    cmpNum=0.0#存储最大值
    ans=-1#存储最大值所在属性的位置
    condNum=0
    while condNum<sum:
        if condNum==placeRoot:
            continue
        else:
            if calCondComentropy(dataSet,condNum,placeRoot)>cmpNum:
                cmpNum=calCondComentropy(dataSet,condNum,placeRoot)
                ans=condNum
        condNum+=1
    return ans
Draw.py
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round", fc="0.8")  # 决策节点的属性。boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8")  # 决策树叶子节点的属性
arrow_args = dict(arrowstyle="<-")  # 箭头的属性
plt.rcParams['font.sans-serif']=['SimHei']#用来正常显示中文
plt.rcParams['axes.unicode_minus']=False#用来正常显示负号
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    # nodeTxt为要显示的文本,xy是箭头尖的坐标,xytest是注释内容的中心坐标
    # xycoords和textcoords是坐标xy与xytext的说明(按轴坐标),若textcoords=None,则默认textcoords与xycoords相同,若都未设置,默认为data
    # va/ha设置节点框中文字的位置,va为纵向取值为(u'top', u'bottom', u'center', u'baseline'),ha为横向取值为(u'center', u'right', u'left')
def plotMidText(cntrPt, parentPt, txtString):   #  在两个节点之间的线上写上字
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)  # text() 的使用
def createPlot():
    #    fig = plt.figure(1, facecolor = 'white') #创建一个画布,背景为白色
    #    fig.clf() #画布清空
    # ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以
    plt.show()
CreateTree.py
from Tree.Comentropy import *
from Tree.Draw import *
class Node:
    def __init__(self,number,condNum):
        self.name=''
        self.number=number
        self.condNum=condNum
        self.index=[]
        self.label=[]
        self.child=[]
        self.childNum=0
class BinaryTree(object):
    def __init__(self,node):
        self.root=node
        self.high=0
        self.leafNum=0
        self.cenX=0#root中心坐标
        self.cenY=0
        self.interX=0#间隔坐标
        self.interY=0
        self.lastX={}#记录某行当前最后一个x坐标,避免图中出现重叠
    def createWholeTree(self,node,placeRoot):
        if judgeSame(node.index,placeRoot)==1:
            node.name=node.index[0][placeRoot]
            return
        else:
            needNum=cmp(node.index,node.condNum,placeRoot)
            node.name = node.label[needNum]
            '''条件个数调试
            dict = {}  # 条件元素
            sizeCond = 0  # 一共有几种该属性条件
            for data in node.index:
                cond = data[needNum]
                if cond not in dict:
                    dict[cond] = 0  # 第一次出现该条件,将其加入字典
                    sizeCond += 1
                dict[cond] += 1
            print(cond,sizeCond)
            '''
            dictRep = {}  # 用于下面的循环,判断某条件是否出现过
            for data in node.index:
                cond = data[needNum]
                if cond not in dictRep:
                    dataSetCond = []  # 将符合该条件的新建数据集存储
                    for dataCond in node.index:
                        if dataCond[needNum] == cond:
                            dataSetCond.append(dataCond)
                    dictRep[cond] = 0  # 标记
                    newNode=Node(cond,node.condNum) #建立子结点
                    newNode.index=dataSetCond
                    newNode.label=node.label
                    node.child.append(newNode)
                    node.childNum+=1
                    self.createWholeTree(newNode, placeRoot)
            return
    def lengthTree(self,node,length):
        if node.childNum==0:
            self.high=max(length,self.high)
            return
        else:
            for child in node.child:
                self.lengthTree(child,length+1)
    def leafTree(self,node):
        #print(node.number)
        if node.childNum==0:
            self.leafNum+=1
            return
        else:
            for child in node.child:
                self.leafTree(child)
    def doMath(self):
        self.interY=(float)(self.high)/(float)(self.high+1)
        self.interX=0.5
        self.cenY=1.0/(float)(self.high+1)
        self.cenX=1.0/(float)(self.leafNum+1)

    def drawTree(self,node,fatherPT,sonPT):
        self.doMath()
        if fatherPT==(0,0) and sonPT==(0,0):#根结点
            newPT=(self.interX,self.interY)
            plotNode(node.name, newPT, newPT, decisionNode)
            sumNumber=node.childNum
            number=1
            for child in node.child:
                locationX=number-(sumNumber+1)/2
                self.drawTree(child,newPT,(self.interX+locationX*self.cenX,self.interY-self.cenY))
                number+=1
        else:
            #print(node.number)
            plotMidText(sonPT, fatherPT, node.number)
            if node.name=='是' or node.name=='否':
                plotNode(node.name, sonPT, fatherPT, leafNode)
                self.lastX[sonPT[1]]=sonPT[0]
                return
            self.lastX[sonPT[1]] = sonPT[0]
            plotNode(node.name, sonPT,fatherPT, decisionNode)
            sumNumber = node.childNum
            number = 1
            for child in node.child:
                locationX = number - (sumNumber + 1) / 2
                if sonPT[1] - self.cenY not in self.lastX:
                    self.drawTree(child, sonPT, (sonPT[0] + locationX * self.cenX, sonPT[1] - self.cenY))
                elif self.lastX[sonPT[1] - self.cenY]+ self.cenX<sonPT[0] + locationX * self.cenX:
                    self.drawTree(child, sonPT, (sonPT[0] + locationX * self.cenX, sonPT[1] - self.cenY))
                else:
                    self.drawTree(child, sonPT, (self.lastX[sonPT[1] - self.cenY] + self.cenX, sonPT[1] - self.cenY))
                number += 1
        return

运行结果展示

在这里插入图片描述
在这里插入图片描述

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐