机器学习-决策树的编写(1):ID3简单实现决策树并可视化
机器学习-决策树的编写(1):ID3简单实现决策树并可视化文章目录机器学习-决策树的编写(1):ID3简单实现决策树并可视化前言决策树的基本概念所用数据集python文件结构代码__init__.pyComentropy.pyDraw.pyCreateTree.py运行结果展示前言本文为作者的第一次尝试编写,没有用到机器学习库,代码都是手撸的,数据都是手动输入,也没有剪枝……但实现效果还不错,便发
机器学习-决策树的编写(1):ID3简单实现决策树并可视化
文章目录
前言
本文为作者的第一次尝试编写,没有用到机器学习库,代码都是手撸的,数据都是手动输入,也没有剪枝……但实现效果还不错,便发出来做一次记录,可供初学者参考。
决策树的基本概念
决策树,简单而言,就是从树的根节点开始,在内部节点处需要做判断,直到到达一个叶子节点处,得到决策结果。
决策树的判别标准:
1.如果节点对应数据子集中的样本基本属于同一个类别,则无需对节点的数据子集做进一步划分,否则就要对该节点的数据子集做进一步划分,生成新的判别标准;
2.如果新判别标准能够基本上把结点上不同类别的数据分离开,使得每个子结点都是类别比较单一的数据,那么该判别标准就是一个好规则,否则需重新选取判别标准。
决策树的构造一般分为三个部分:特征选择(本文使用ID3,即使用“信息增益准则”),决策树生成,剪枝(本文未使用)
信息熵:度量样本集合“纯度”最常用的一种指标。假定当前样本集合 D 中第 k 类样本所占的比例为p_i ,则D的信息熵定义为:
其值越大,则表明D中所包含样本标签取值越杂乱。
信息增益:以信息熵为基础,计算当前划分对信息熵所造成的变化。以属性a对数据集D进行划分所获得的信息增益为:
ID3选择信息增益最大的为“划分属性”。
所用数据集
由于是第一次“实验性质”的编写,数据集直接用了书上的“西瓜数据集”:
python文件结构
代码
init.py
from Tree.Comentropy import *
from Tree.CreateTree import *
from Tree.Draw import *
def main():
dataLabel=['色泽','根蒂','敲声','纹理','脐部','触感','好瓜']#数据标签
dataSet=[ ['青绿','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['浅白','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否'],
['青绿','硬挺','清脆','清晰','平坦','软粘','否'],
['浅白','硬挺','清脆','模糊','平坦','硬滑','否'],
['浅白','蜷缩','浊响','模糊','平坦','软粘','否'],
['青绿','稍蜷','浊响','稍糊','凹陷','硬滑','否'],
['浅白','稍蜷','沉闷','稍糊','凹陷','硬滑','否'],
['乌黑','稍蜷','浊响','清晰','稍凹','软粘','否'],
['浅白','蜷缩','浊响','模糊','平坦','硬滑','否'],
['青绿','蜷缩','沉闷','稍糊','稍凹','硬滑','否'] ]#数据集
print("信息熵为{}".format(calComentropy(dataSet,6)))
print("色泽的信息增益为{}".format(calCondComentropy(dataSet,0,6)))
root=Node('根',6)
root.index=dataSet
root.label=dataLabel
#print(root.index)
tree=BinaryTree(root)
tree.createWholeTree(root,6)
tree.lengthTree(root,1)
print("树的高度是",tree.high)
tree.leafTree(root)
print("树的叶子结点个数", tree.leafNum, end='')
#createPlot()
createPlot.ax1 = plt.subplot(111, frameon=False) # frameon表示是否绘制坐标轴矩形
tree.drawTree(root, (0,0), (0,0))
plt.axis('off')# 去掉坐标轴以及刻度
plt.show()
if __name__ == '__main__':
main()
Comentropy.py
from math import log
def calComentropy(dataSet, place):#计算信息熵
lenDataSet=len(dataSet)
dict={}#决策元素出现了几次
ans=0.0
for data in dataSet:
needJudge=data[place]#取决策元素
if needJudge not in dict:
dict[needJudge]=0#第一次出现,将其加入字典
dict[needJudge]+=1#出现次数
for key in dict:
p = float(dict[key]) / float(lenDataSet)
ans-=p*log(p,2)
return ans
def calCondComentropy(dataSet,placeCon,placeRoot):#计算信息增量
lenDataSet=len(dataSet)
ent=[]#第i个元素:三个值,分别为条件名,出现次数,熵值
sizeCond=0#一共有几种该属性条件
dict={}#条件元素出现了几次
for data in dataSet:
cond=data[placeCon]
if cond not in dict:
dict[cond]=0#第一次出现该条件,将其加入字典
sizeCond+=1
dict[cond]+=1
dictRep = {}#用于下面的循环,判断某条件是否出现过
for data in dataSet:
cond=data[placeCon]
if cond not in dictRep:
dataSetCond=[]# 将符合该条件的新建数据集存储
for dataCond in dataSet:
if dataCond[placeCon]==cond:
dataSetCond.append(dataCond)
#print(cond,' ',dict[cond],' ',calComentropy(dataSetCond,placeRoot))
ent.append([cond,dict[cond],calComentropy(dataSetCond,placeRoot)])
dictRep[cond] = 0#标记
cmentropyNum=0.0#计算信息增量的减数
ans=0.0
for entCond in ent:
cmentropyNum+=entCond[2]*float(entCond[1])/float(lenDataSet)
ans=calComentropy(dataSet,placeRoot)-cmentropyNum
return ans
def judgeSame(dataSet,placeRoot):#判断结果是否相同
dict = {}#几种结果
for data in dataSet:
if data[placeRoot] not in dict:
dict[data[placeRoot]]=1#加入字典
if len(dict)==1:
return 1
else:
return len(dict)
def cmp(dataSet,sum,placeRoot):#比较信息增量
cmpNum=0.0#存储最大值
ans=-1#存储最大值所在属性的位置
condNum=0
while condNum<sum:
if condNum==placeRoot:
continue
else:
if calCondComentropy(dataSet,condNum,placeRoot)>cmpNum:
cmpNum=calCondComentropy(dataSet,condNum,placeRoot)
ans=condNum
condNum+=1
return ans
Draw.py
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="round", fc="0.8") # 决策节点的属性。boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
leafNode = dict(boxstyle="round4", fc="0.8") # 决策树叶子节点的属性
arrow_args = dict(arrowstyle="<-") # 箭头的属性
plt.rcParams['font.sans-serif']=['SimHei']#用来正常显示中文
plt.rcParams['axes.unicode_minus']=False#用来正常显示负号
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# nodeTxt为要显示的文本,xy是箭头尖的坐标,xytest是注释内容的中心坐标
# xycoords和textcoords是坐标xy与xytext的说明(按轴坐标),若textcoords=None,则默认textcoords与xycoords相同,若都未设置,默认为data
# va/ha设置节点框中文字的位置,va为纵向取值为(u'top', u'bottom', u'center', u'baseline'),ha为横向取值为(u'center', u'right', u'left')
def plotMidText(cntrPt, parentPt, txtString): # 在两个节点之间的线上写上字
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString) # text() 的使用
def createPlot():
# fig = plt.figure(1, facecolor = 'white') #创建一个画布,背景为白色
# fig.clf() #画布清空
# ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以
plt.show()
CreateTree.py
from Tree.Comentropy import *
from Tree.Draw import *
class Node:
def __init__(self,number,condNum):
self.name=''
self.number=number
self.condNum=condNum
self.index=[]
self.label=[]
self.child=[]
self.childNum=0
class BinaryTree(object):
def __init__(self,node):
self.root=node
self.high=0
self.leafNum=0
self.cenX=0#root中心坐标
self.cenY=0
self.interX=0#间隔坐标
self.interY=0
self.lastX={}#记录某行当前最后一个x坐标,避免图中出现重叠
def createWholeTree(self,node,placeRoot):
if judgeSame(node.index,placeRoot)==1:
node.name=node.index[0][placeRoot]
return
else:
needNum=cmp(node.index,node.condNum,placeRoot)
node.name = node.label[needNum]
'''条件个数调试
dict = {} # 条件元素
sizeCond = 0 # 一共有几种该属性条件
for data in node.index:
cond = data[needNum]
if cond not in dict:
dict[cond] = 0 # 第一次出现该条件,将其加入字典
sizeCond += 1
dict[cond] += 1
print(cond,sizeCond)
'''
dictRep = {} # 用于下面的循环,判断某条件是否出现过
for data in node.index:
cond = data[needNum]
if cond not in dictRep:
dataSetCond = [] # 将符合该条件的新建数据集存储
for dataCond in node.index:
if dataCond[needNum] == cond:
dataSetCond.append(dataCond)
dictRep[cond] = 0 # 标记
newNode=Node(cond,node.condNum) #建立子结点
newNode.index=dataSetCond
newNode.label=node.label
node.child.append(newNode)
node.childNum+=1
self.createWholeTree(newNode, placeRoot)
return
def lengthTree(self,node,length):
if node.childNum==0:
self.high=max(length,self.high)
return
else:
for child in node.child:
self.lengthTree(child,length+1)
def leafTree(self,node):
#print(node.number)
if node.childNum==0:
self.leafNum+=1
return
else:
for child in node.child:
self.leafTree(child)
def doMath(self):
self.interY=(float)(self.high)/(float)(self.high+1)
self.interX=0.5
self.cenY=1.0/(float)(self.high+1)
self.cenX=1.0/(float)(self.leafNum+1)
def drawTree(self,node,fatherPT,sonPT):
self.doMath()
if fatherPT==(0,0) and sonPT==(0,0):#根结点
newPT=(self.interX,self.interY)
plotNode(node.name, newPT, newPT, decisionNode)
sumNumber=node.childNum
number=1
for child in node.child:
locationX=number-(sumNumber+1)/2
self.drawTree(child,newPT,(self.interX+locationX*self.cenX,self.interY-self.cenY))
number+=1
else:
#print(node.number)
plotMidText(sonPT, fatherPT, node.number)
if node.name=='是' or node.name=='否':
plotNode(node.name, sonPT, fatherPT, leafNode)
self.lastX[sonPT[1]]=sonPT[0]
return
self.lastX[sonPT[1]] = sonPT[0]
plotNode(node.name, sonPT,fatherPT, decisionNode)
sumNumber = node.childNum
number = 1
for child in node.child:
locationX = number - (sumNumber + 1) / 2
if sonPT[1] - self.cenY not in self.lastX:
self.drawTree(child, sonPT, (sonPT[0] + locationX * self.cenX, sonPT[1] - self.cenY))
elif self.lastX[sonPT[1] - self.cenY]+ self.cenX<sonPT[0] + locationX * self.cenX:
self.drawTree(child, sonPT, (sonPT[0] + locationX * self.cenX, sonPT[1] - self.cenY))
else:
self.drawTree(child, sonPT, (self.lastX[sonPT[1] - self.cenY] + self.cenX, sonPT[1] - self.cenY))
number += 1
return
运行结果展示
更多推荐
所有评论(0)