Machine Learning 6: Decision Tree Code
1. ID3
1.1 Required Packages
Both have been uploaded as free resources.
- jdom.jar
- weka.jar
1.2 Java Code
package machinelearning.decisiontree;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;
import org.jdom2.Document;
import org.jdom2.Element;
import org.jdom2.output.Format;
import org.jdom2.output.XMLOutputter;
import weka.core.Instance;
import weka.core.Instances;
/**
 * *****************************************************
 * Usage                                                *
 * 1. Copy the data below into a text file, save it to *
 *    the D: drive, and name it weather.arff.          *
 * 2. Write your own data in the same format.          *
 * *****************************************************
 */
/**
@relation weather
@attribute Outlook {Sunny, Overcast, Rain}
@attribute Temperature {Hot, Mild, Cool}
@attribute Humidity {High, Normal, Low}
@attribute Windy {FALSE, TRUE}
@attribute Play {N, P}
@data
Sunny,Hot,High,FALSE,N
Sunny,Hot,High,TRUE,N
Overcast,Hot,High,FALSE,P
Rain,Mild,High,FALSE,P
Rain,Cool,Normal,FALSE,P
Rain,Cool,Normal,TRUE,N
Overcast,Cool,Normal,TRUE,P
Sunny,Mild,High,FALSE,N
Sunny,Cool,Normal,FALSE,P
Rain,Mild,Normal,FALSE,P
Sunny,Mild,Normal,TRUE,P
Overcast,Mild,High,TRUE,P
Overcast,Hot,Normal,FALSE,P
Rain,Mild,High,TRUE,N
*/
/**
 * Iterative Dichotomiser 3 (ID3).
 *
 * ID3.java Created on: April 2, 2022
 *
 * @author NoBug
 *
 */
public class ID3 {
/**
 * The dataset.
 */
Instances dataset;
/**
 * Is this dataset pure? True iff all available instances share one class.
 */
boolean pure;
/**
* The number of classes. For binary classification it is 2.
*/
int numClasses;
/**
 * Available instances. Each entry is a row index into the dataset.
 */
int[] availableInstances;
/**
* Available attributes. Candidate attributes.
*/
int[] availableAttributes;
/**
 * The selected attribute (key step: the attribute chosen for splitting).
 */
int splitAttribute;
/**
* The children nodes.
*/
ID3[] children;
/**
* Decision result
*/
int result;
/**
* The prediction, including queried and predicted results.
*/
int[] predicts;
/**
 * A block no larger than this threshold is not split further.
 */
static int smallBlockThreshold = 3;
public ID3() {
}
/**
 * Constructor for an ARFF file.
 *
 * @param fileURL the path of the ARFF file.
 */
public ID3(String fileURL) {
dataset = null;
try {
FileReader fileReader = new FileReader(fileURL);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception e) {
System.out.println("Cannot read the file: " + fileURL + "\r\n" + e);
System.exit(0);
} // Of try
dataset.setClassIndex(dataset.numAttributes() - 1);
numClasses = dataset.classAttribute().numValues();
availableInstances = new int[dataset.numInstances()];
for (int i = 0; i < availableInstances.length; i++) {
availableInstances[i] = i;
} // Of for i
availableAttributes = new int[dataset.numAttributes() - 1];
for (int i = 0; i < availableAttributes.length; i++) {
availableAttributes[i] = i;
} // Of for i
// Initialize.
children = null;
result = getMajorityClassIndex(availableInstances);
pure = isPure(availableInstances);
}// first constructor
/**
 * Constructor for a given dataset restricted to the given instances and
 * attributes.
 *
 * @param paraDataset
 * @param paraAvailableInstances
 * @param paraAvailableAttributes
 */
public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
dataset = paraDataset;
availableInstances = paraAvailableInstances;
availableAttributes = paraAvailableAttributes;
children = null;
result = getMajorityClassIndex(availableInstances);
pure = isPure(availableInstances);
}// second constructor
/**
 * Is the given block of instances pure?
 *
 * @param paraBlock indices of the instances to check.
 * @return true if all instances in the block share one class.
 */
public boolean isPure(int[] paraBlock) {
pure = true;
for (int i = 1; i < paraBlock.length; i++) {
if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0]).classValue()) {
pure = false;
break;
} // Of if
} // Of for i
return pure;
}// Of isPure
/**
 * Get the majority class of the given block.
 *
 * @param paraBlock indices of the instances to check.
 * @return the index of the majority class.
 */
public int getMajorityClassIndex(int[] paraBlock) {
int[] classes = new int[dataset.numClasses()]; // numClasses = 2, so classes.length = 2
for (int i = 0; i < paraBlock.length; i++) {
classes[(int) dataset.instance(paraBlock[i]).classValue()]++;
} // Of for i
int index = -1;
int tempCount = -1;
for (int i = 0; i < classes.length; i++) {
if (tempCount < classes[i]) {
index = i;
tempCount = classes[i];
} // Of if
} // Of for i
return index;
}// Of getMajorityClassIndex
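/*
 * Worked example: over all 14 weather instances there are 9 P and 5 N, so
 * the majority class index is 1 (the class attribute is {N, P}).
 */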
/**
 * Select the best attribute.
 *
 * Maximizing the information gain and minimizing the conditional entropy
 * are equivalent, so the attribute with minimum conditional entropy wins.
 *
 * @return The index of the best attribute.
 */
public int selectBestAttribute() {
splitAttribute = -1;
double tempMinEntropy = Double.MAX_VALUE;
double tempEntropy;
for (int i = 0; i < availableAttributes.length; i++) {
tempEntropy = getAttributeEnt(availableAttributes[i]);
if (tempMinEntropy > tempEntropy) {
tempMinEntropy = tempEntropy;
splitAttribute = availableAttributes[i];
} // Of if
} // Of for i
return splitAttribute;
}// selectBestAttribute
/**
 * Compute the conditional entropy of the class given the attribute.
 *
 * Minimizing this entropy maximizes the information gain.
 *
 * @param paraAttribute the attribute to evaluate.
 * @return the conditional entropy in nats (Math.log is the natural log).
 */
public double getAttributeEnt(int paraAttribute) {
// Step 1. Statistics.
int tempNumClasses = dataset.numClasses();
int tempNumValues = dataset.attribute(paraAttribute).numValues();
int tempNumInstances = availableInstances.length;
double[] tempValueCounts = new double[tempNumValues];
double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];
int tempClass, tempValue;
for (int i = 0; i < tempNumInstances; i++) {
tempClass = (int) dataset.instance(availableInstances[i]).classValue();
tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
tempValueCounts[tempValue]++;
tempCountMatrix[tempValue][tempClass]++;
} // Of for i
// Step 2. Compute the entropy of each value block, weighted by block size.
double resultEntropy = 0;
double tempEntropy, tempFraction;
for (int i = 0; i < tempNumValues; i++) {
if (tempValueCounts[i] == 0) {
continue;
} // Of if
tempEntropy = 0;
for (int j = 0; j < tempNumClasses; j++) {
tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
if (tempFraction == 0) {
continue;
} // Of if
tempEntropy += -tempFraction * Math.log(tempFraction);
} // Of for j
resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
} // Of for i
return resultEntropy;
}// getAttributeEnt
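/*
 * Worked example on the weather data, for Outlook: Sunny covers 5 instances
 * (2 P, 3 N), Overcast 4 (4 P, 0 N) and Rain 5 (3 P, 2 N), so the
 * conditional entropy is 5/14 * 0.673 + 4/14 * 0 + 5/14 * 0.673 ~= 0.481
 * nats (~0.694 bits), the smallest of the four attributes. Math.log is the
 * natural log; a base-2 log would scale every attribute by the same
 * constant and leave the selection unchanged.
 */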
/**
 * Split the data according to the given attribute.
 *
 * @param paraSplitAttribute the attribute to split on.
 * @return one block of instance indices per attribute value.
 */
public int[][] splitData(int paraSplitAttribute) {
int numValues = dataset.attribute(paraSplitAttribute).numValues();
int[][] resultBlocks = new int[numValues][];
int[] tempSizes = new int[numValues];
// First scan to count the size of each block.
int tempValue;
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraSplitAttribute);
tempSizes[tempValue]++;
} // Of for i
// Allocate space.
for (int i = 0; i < numValues; i++) {
resultBlocks[i] = new int[tempSizes[i]];
} // Of for i
// Second scan to fill.
Arrays.fill(tempSizes, 0);
for (int i = 0; i < availableInstances.length; i++) {
tempValue = (int) dataset.instance(availableInstances[i]).value(paraSplitAttribute);
// Copy data.
resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
tempSizes[tempValue]++;
} // Of for i
return resultBlocks;
}// Of splitData
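/*
 * Worked example: splitting the full weather data on Outlook yields the
 * blocks [0, 1, 7, 8, 10] (Sunny), [2, 6, 11, 12] (Overcast) and
 * [3, 4, 5, 9, 13] (Rain).
 */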
/**
* Build decision tree.
*/
public void buildDecisionTree() {
if (isPure(availableInstances)) {
return;
} // Of if
if (availableInstances.length <= smallBlockThreshold) {
return;
} // Of if
selectBestAttribute();
int[][] tempSubBlocks = splitData(splitAttribute);
children = new ID3[tempSubBlocks.length];
// Construct the remaining attribute set.
int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
for (int i = 0; i < availableAttributes.length; i++) {
if (availableAttributes[i] < splitAttribute) {
tempRemainingAttributes[i] = availableAttributes[i];
} else if (availableAttributes[i] > splitAttribute) {
tempRemainingAttributes[i - 1] = availableAttributes[i];
} // Of if
} // Of for i
// Construct children.
for (int i = 0; i < children.length; i++) {
if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
children[i] = null;
} else {
children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
children[i].buildDecisionTree();
} // Of if
} // Of for i
}// buildDecisionTree
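/*
 * On the weather data the recursion proceeds as in the sample output below:
 * the root splits on Outlook, the pure Overcast block (all P) stops, and
 * the Sunny and Rain blocks are split again on Humidity and Windy.
 */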
/**
* Classify an instance.
*
* @param paraInstance
* @return The prediction.
*/
public int classify(Instance paraInstance) {
if (children == null) {
return result;
} // Of if
ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
if (tempChild == null) {
return result;
} // Of if
return tempChild.classify(paraInstance);
}// classify
/**
 * Get the accuracy on the given dataset.
 *
 * @param paraDataset the dataset to evaluate on.
 * @return the fraction of correctly classified instances.
 */
public double getAccuracy(Instances paraDataset) {
int correctNum = 0;
for (int i = 0; i < paraDataset.numInstances(); i++) {
if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
correctNum++;
} // Of if
} // Of for i
return (double) correctNum / paraDataset.numInstances();
}// getAccuracy
/**
 * Accuracy on the training set itself.
 *
 * @return the training accuracy.
 */
public double selfAccuracy() {
return getAccuracy(dataset);
}// Of selfAccuracy
/**
 * The tree structure.
 *
 * @return the decision path as a string.
 */
public String toString() {
String decisionPath = "";
if (children == null) {
decisionPath += "class = " + result;
} else {
String attributeName = dataset.attribute(splitAttribute).name();
for (int i = 0; i < children.length; i++) {
if (children[i] == null) {
decisionPath += attributeName + " = " + dataset.attribute(splitAttribute).value(i) + ":"
+ "class = " + result + "\n";
} else {
decisionPath += attributeName + " = " + dataset.attribute(splitAttribute).value(i) + ":"
+ children[i] + "\n";
} // Of if
} // Of for i
} // Of if
return decisionPath;
}// Of toString
/**
 * Write the result to an XML file. For now only the root element with the
 * overall accuracy is written; serializing the tree itself is left as a
 * to-do (see the sketch after this method).
 *
 * @param fileURL the path of the output file.
 */
public void writeToXML(String fileURL) {
// Initialize the XML document with an ID3 root element.
Document xmldoc = new Document();
Element root = new Element("ID3");
xmldoc.setRootElement(root);
root.setAttribute("accuracy", "" + selfAccuracy());
// TODO: write the decision tree itself into the document.
try {
FileWriter writer = new FileWriter(fileURL);
XMLOutputter outputter = new XMLOutputter(Format.getPrettyFormat().setEncoding("gb2312"));
outputter.output(xmldoc, writer); // Write to the file.
outputter.output(xmldoc, System.out); // Echo to the console.
writer.close();
} catch (IOException e) {
System.out.println(e);
} // Of try
}// Of writeToXML
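/*
 * The tree itself is not serialized yet (the sample XML in Section 1.3
 * holds only the accuracy attribute). Below is a minimal sketch of one
 * recursive way to do it with JDOM2; the element and attribute names
 * ("node", "value", "splitAttribute", "class") are assumptions, not from
 * the original post.
 */
public Element toXMLElement() {
Element node = new Element("node");
if (children == null) {
// A leaf node: record the decision result.
node.setAttribute("class", "" + result);
return node;
} // Of if
node.setAttribute("splitAttribute", dataset.attribute(splitAttribute).name());
for (int i = 0; i < children.length; i++) {
// An empty branch falls back to this node's majority class.
Element child = (children[i] == null) ? new Element("node") : children[i].toXMLElement();
child.setAttribute("value", dataset.attribute(splitAttribute).value(i));
if (children[i] == null) {
child.setAttribute("class", "" + result);
} // Of if
node.addContent(child);
} // Of for i
return node;
}// Of toXMLElement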
/**
 * Build the tree, report its accuracy, and write the result to XML.
 */
public void learning() {
ID3 demo = new ID3("D:/weather.arff");
demo.buildDecisionTree();
System.out.println("The decision path is:\n\n" + demo);
System.out.println("The accuracy is: " + demo.selfAccuracy() * 100 + "%");
System.out.println("\nThe tree is:\n");
demo.writeToXML("D:/dt.xml");
}// Of learning
/**
 * Predict the class of an unseen instance, e.g. Sunny,Hot,High,FALSE,?
 * (not implemented yet; a hedged sketch follows).
 */
public void predict() {
}// Of predict
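/*
 * A minimal sketch of how predict() might be filled in, assuming a Weka
 * version (3.7+) that provides weka.core.DenseInstance. The method name and
 * the hard-coded example are illustrative assumptions, not part of the
 * original post.
 */
public int predictExample() {
// Build the unseen instance Sunny,Hot,High,FALSE,? with nominal values set by name.
Instance tempInstance = new weka.core.DenseInstance(dataset.numAttributes());
tempInstance.setDataset(dataset); // Required so nominal names resolve to indices.
tempInstance.setValue(0, "Sunny");
tempInstance.setValue(1, "Hot");
tempInstance.setValue(2, "High");
tempInstance.setValue(3, "FALSE");
// Walk the tree; the class attribute stays missing and is never read.
return classify(tempInstance);
}// Of predictExample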
public static void main(String[] args) {
ID3 id3 = new ID3();
id3.learning();
}// Of main
}// ID3
1.3 Sample Output
The decision path is:
Outlook = Sunny:Humidity = High:class = 0
Humidity = Normal:class = 1
Humidity = Low:class = 0
Outlook = Overcast:class = 1
Outlook = Rain:Windy = FALSE:class = 1
Windy = TRUE:class = 0
The accuracy is: 100.0%
The tree is:
<?xml version="1.0" encoding="gb2312"?>
<ID3 accuracy="1.0" />
2. To Be Continued