1. ID3

1.1 要用到的包

都已上传免费资源。

  • jdom.jar
  • weka.jar

1.2 Java代码

package machinelearning.decisiontree;

import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;

import org.jdom2.Document;
import org.jdom2.Element;
import org.jdom2.output.Format;
import org.jdom2.output.XMLOutputter;

import weka.core.Instance;
import weka.core.Instances;

/**
 * *****************************************************
 * 使用说明                                              *
 * 1. 把下面的数据,复制到文本文件,放到D盘,重命名为weather.arff  *
 * 2. 自己的数据,按照下面的格式写就行。                        *
 * *****************************************************
 */

/**
@relation weather
@attribute Outlook {Sunny, Overcast, Rain}
@attribute Temperature {Hot, Mild, Cool}
@attribute Humidity {High, Normal, Low}
@attribute Windy {FALSE, TRUE}
@attribute Play {N, P}
@data
Sunny,Hot,High,FALSE,N
Sunny,Hot,High,TRUE,N
Overcast,Hot,High,FALSE,P
Rain,Mild,High,FALSE,P
Rain,Cool,Normal,FALSE,P
Rain,Cool,Normal,TRUE,N
Overcast,Cool,Normal,TRUE,P
Sunny,Mild,High,FALSE,N
Sunny,Cool,Normal,FALSE,P
Rain,Mild,Normal,FALSE,P
Sunny,Mild,Normal,TRUE,P
Overcast,Mild,High,TRUE,P
Overcast,Hot,Normal,FALSE,P
Rain,Mild,High,TRUE,N
*/

/**
 * Iterative Dichotomiser 3 (ID3) decision tree over nominal attributes.
 * 
 * Each object of this class is one node of the tree. A node stores the indices
 * of the instances and attributes still available to it, the majority class of
 * those instances ({@link #result}) and, once {@link #buildDecisionTree()} has
 * split it, one child per value of {@link #splitAttribute}.
 * 
 * ID3.java Created on: 2022-04-02
 * 
 * @author NoBug
 *
 */
public class ID3 {

	/**
	 * The whole data set. Every node of one tree shares the same object.
	 */
	Instances dataset;

	/**
	 * Whether all available instances of this node carry the same class value.
	 */
	boolean pure;

	/**
	 * The number of classes. For binary classification it is 2.
	 */
	int numClasses;

	/**
	 * Available instances: indices of the rows of {@link #dataset} that reached
	 * this node.
	 */
	int[] availableInstances;

	/**
	 * Available attributes: indices of the candidate attributes that have not
	 * been used for splitting on the path to this node.
	 */
	int[] availableAttributes;

	/**
	 * The selected split attribute. Only meaningful when {@link #children} is
	 * not null.
	 */
	int splitAttribute;

	/**
	 * The children nodes, indexed by the value of the split attribute. Null for
	 * a leaf; a single entry is null when no instance has that value.
	 */
	ID3[] children;

	/**
	 * Decision result: the majority class index of the available instances.
	 */
	int result;

	/**
	 * The prediction, including queried and predicted results.
	 */
	int[] predicts;

	/**
	 * Small block cannot be split further.
	 */
	static int smallBlockThreshold = 3;

	/**
	 * Creates an empty node without data. Only useful as a receiver for
	 * {@link #learning()}.
	 */
	public ID3() {
	}

	/**
	 * Creates the root node from an ARFF file. The last attribute is used as the
	 * class attribute, and all instances and all other attributes are available.
	 * 
	 * If the file cannot be read, a message is printed and the JVM exits.
	 * 
	 * @param fileURL path of the ARFF file.
	 */
	public ID3(String fileURL) {
		dataset = null;
		// try-with-resources closes the reader even when parsing fails.
		try (FileReader fileReader = new FileReader(fileURL)) {
			dataset = new Instances(fileReader);
		} catch (Exception e) {
			System.out.println("Cannot read the file: " + fileURL + "\r\n" + e);
			System.exit(0);
		} // Of try

		dataset.setClassIndex(dataset.numAttributes() - 1);
		numClasses = dataset.classAttribute().numValues();

		availableInstances = new int[dataset.numInstances()];
		for (int i = 0; i < availableInstances.length; i++) {
			availableInstances[i] = i;
		} // Of for i
		availableAttributes = new int[dataset.numAttributes() - 1];
		for (int i = 0; i < availableAttributes.length; i++) {
			availableAttributes[i] = i;
		} // Of for i

		// Initialize.
		children = null;
		result = getMajorityClassIndex(availableInstances);
		pure = isPure(availableInstances);
	}// first constructor

	/**
	 * Creates a node over a subset of an existing data set. The arrays are stored
	 * as given, not copied.
	 * 
	 * @param paraDataset             the whole data set, with its class index set.
	 * @param paraAvailableInstances  indices of the instances of this node.
	 * @param paraAvailableAttributes indices of the attributes still usable for
	 *                                splitting.
	 */
	public ID3(Instances paraDataset, int[] paraAvailableInstances, int[] paraAvailableAttributes) {
		dataset = paraDataset;
		numClasses = dataset.numClasses();
		availableInstances = paraAvailableInstances;
		availableAttributes = paraAvailableAttributes;

		children = null;
		result = getMajorityClassIndex(availableInstances);
		pure = isPure(availableInstances);
	}// second constructor

	/**
	 * Checks whether all instances of a block share one class value. This is a
	 * pure query: it does not modify the {@link #pure} field of the node.
	 * 
	 * @param paraBlock indices of the instances to check.
	 * @return true if the block is empty or every instance has the class value of
	 *         the first one.
	 */
	public boolean isPure(int[] paraBlock) {
		for (int i = 1; i < paraBlock.length; i++) {
			if (dataset.instance(paraBlock[i]).classValue() != dataset.instance(paraBlock[0]).classValue()) {
				return false;
			} // Of if
		} // Of for i

		return true;
	}// isPure

	/**
	 * Finds the most frequent class among the instances of a block.
	 * 
	 * @param paraBlock indices of the instances to count.
	 * @return the index of the majority class; ties are resolved in favor of the
	 *         smaller class index.
	 */
	public int getMajorityClassIndex(int[] paraBlock) {
		// One counter per class value.
		int[] classes = new int[dataset.numClasses()];
		for (int i = 0; i < paraBlock.length; i++) {
			classes[(int) dataset.instance(paraBlock[i]).classValue()]++;
		} // Of for i

		int index = -1;
		int tempCount = -1;
		for (int i = 0; i < classes.length; i++) {
			// Strict comparison keeps the first class on ties.
			if (tempCount < classes[i]) {
				index = i;
				tempCount = classes[i];
			} // Of if
		} // Of for i

		return index;
	}// Of getMajorityClass

	/**
	 * Selects the best split attribute and stores it in {@link #splitAttribute}.
	 * 
	 * Maximizing the information gain is equivalent to minimizing the conditional
	 * entropy, so the attribute with the smallest conditional entropy is chosen.
	 * 
	 * @return The index of the best attribute, or -1 if no attribute is available.
	 */
	public int selectBestAttribute() {
		splitAttribute = -1;
		double tempMinEntropy = Double.MAX_VALUE;
		double tempEntropy;

		for (int i = 0; i < availableAttributes.length; i++) {
			tempEntropy = getAttributeEnt(availableAttributes[i]);
			if (tempMinEntropy > tempEntropy) {
				tempMinEntropy = tempEntropy;
				splitAttribute = availableAttributes[i];
			} // Of if
		} // Of for i

		return splitAttribute;
	}// selectBestAttribute

	/**
	 * Computes the conditional entropy of the class given one attribute, over the
	 * available instances of this node. The natural logarithm is used.
	 * 
	 * @param paraAttribute the index of the attribute.
	 * @return the weighted sum of the class entropies of the attribute's values.
	 */
	public double getAttributeEnt(int paraAttribute) {
		// Step 1. Statistics: count instances per value and per (value, class).
		int tempNumClasses = dataset.numClasses();
		int tempNumValues = dataset.attribute(paraAttribute).numValues();
		int tempNumInstances = availableInstances.length;
		double[] tempValueCounts = new double[tempNumValues];
		double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];

		int tempClass, tempValue;
		for (int i = 0; i < tempNumInstances; i++) {
			tempClass = (int) dataset.instance(availableInstances[i]).classValue();
			tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
			tempValueCounts[tempValue]++;
			tempCountMatrix[tempValue][tempClass]++;
		} // Of for i

		// Step 2. Sum the entropy of each value, weighted by the value's share.
		double resultEntropy = 0;
		double tempEntropy, tempFraction;
		for (int i = 0; i < tempNumValues; i++) {
			// A value without instances contributes nothing.
			if (tempValueCounts[i] == 0) {
				continue;
			} // Of if
			tempEntropy = 0;
			for (int j = 0; j < tempNumClasses; j++) {
				tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
				// Skip empty cells: log(0) is undefined.
				if (tempFraction == 0) {
					continue;
				} // Of if
				tempEntropy += -tempFraction * Math.log(tempFraction);
			} // Of for j
			resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
		} // Of for i

		return resultEntropy;
	}// getAttributeEnt

	/**
	 * Split the available instances according to the given attribute.
	 * 
	 * @param paraSplitAttribute the index of the attribute to split on.
	 * @return one block of instance indices per attribute value; a block is empty
	 *         when no instance has that value.
	 */
	public int[][] splitData(int paraSplitAttribute) {
		int numValues = dataset.attribute(paraSplitAttribute).numValues();
		int[][] resultBlocks = new int[numValues][];
		int[] tempSizes = new int[numValues];

		// First scan to count the size of each block.
		int tempValue;
		for (int i = 0; i < availableInstances.length; i++) {
			tempValue = (int) dataset.instance(availableInstances[i]).value(paraSplitAttribute);
			tempSizes[tempValue]++;
		} // Of for i

		// Allocate space.
		for (int i = 0; i < numValues; i++) {
			resultBlocks[i] = new int[tempSizes[i]];
		} // Of for i

		// Second scan to fill. The sizes are reused as write positions.
		Arrays.fill(tempSizes, 0);
		for (int i = 0; i < availableInstances.length; i++) {
			tempValue = (int) dataset.instance(availableInstances[i]).value(paraSplitAttribute);
			// Copy data.
			resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
			tempSizes[tempValue]++;
		} // Of for i

		return resultBlocks;
	}// Of splitData

	/**
	 * Build the decision tree below this node, recursively.
	 * 
	 * The node stays a leaf when its instances are pure, when it holds at most
	 * {@link #smallBlockThreshold} instances, or when no attribute is left to
	 * split on.
	 */
	public void buildDecisionTree() {
		if (isPure(availableInstances)) {
			return;
		} // Of if
		if (availableInstances.length <= smallBlockThreshold) {
			return;
		} // Of if
		// Without candidate attributes no split is possible; without this guard
		// selectBestAttribute() would return -1 and splitData(-1) would fail.
		if (availableAttributes.length == 0) {
			return;
		} // Of if

		selectBestAttribute();
		int[][] tempSubBlocks = splitData(splitAttribute);
		children = new ID3[tempSubBlocks.length];

		// Construct the remaining attribute set: everything but the split
		// attribute, in the original order.
		int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
		int tempCount = 0;
		for (int i = 0; i < availableAttributes.length; i++) {
			if (availableAttributes[i] != splitAttribute) {
				tempRemainingAttributes[tempCount] = availableAttributes[i];
				tempCount++;
			} // Of if
		} // Of for i

		// Construct children.
		for (int i = 0; i < children.length; i++) {
			if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
				// No instance has this value; classify() falls back to this
				// node's result.
				children[i] = null;
			} else {
				children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);
				children[i].buildDecisionTree();
			} // Of if
		} // Of for i
	}// buildDecisionTree

	/**
	 * Classify an instance by walking down the tree.
	 * 
	 * @param paraInstance the instance to classify.
	 * @return The predicted class index. When the walk reaches a leaf or a missing
	 *         child, the majority class of the current node is returned.
	 */
	public int classify(Instance paraInstance) {
		if (children == null) {
			return result;
		} // Of if

		ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
		if (tempChild == null) {
			return result;
		} // Of if

		return tempChild.classify(paraInstance);
	}// classify

	/**
	 * Get the accuracy of this tree on a data set.
	 * 
	 * @param paraDataset the data set to test on, with its class index set.
	 * @return the fraction of correctly classified instances; NaN for an empty
	 *         data set.
	 */
	public double getAccuracy(Instances paraDataset) {
		int correctNum = 0;
		for (int i = 0; i < paraDataset.numInstances(); i++) {
			if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
				correctNum++;
			} // Of if
		} // Of for i

		return (double) correctNum / paraDataset.numInstances();
	}// getAccuracy

	/**
	 * Get the accuracy on the training data set itself.
	 * 
	 * @return the fraction of correctly classified training instances.
	 */
	public double selfAccuracy() {
		return getAccuracy(dataset);
	}// Of selfTest

	/**
	 * The tree structure as text.
	 * 
	 * @return the decision paths: "class = c" for a leaf, otherwise one line
	 *         "attribute = value:subtree" per value of the split attribute.
	 */
	@Override
	public String toString() {
		if (children == null) {
			return "class = " + result;
		} // Of if

		// Only an internal node has a meaningful split attribute.
		String attributeName = dataset.attribute(splitAttribute).name();
		StringBuilder decisionPath = new StringBuilder();
		for (int i = 0; i < children.length; i++) {
			decisionPath.append(attributeName).append(" = ").append(dataset.attribute(splitAttribute).value(i))
					.append(":");
			if (children[i] == null) {
				// A missing child is decided by this node's majority class.
				decisionPath.append("class = ").append(result);
			} else {
				decisionPath.append(children[i].toString());
			} // Of if
			decisionPath.append("\n");
		} // Of for i

		return decisionPath.toString();
	}// Of toString

	/**
	 * Writes an XML document to the given file and echoes it to the console.
	 * Currently the document only holds a root element "ID3" with the training
	 * accuracy; the tree itself is not written yet.
	 * 
	 * An IOException is printed, not propagated.
	 * 
	 * @param fileURL path of the XML file to write.
	 */
	public void writeToXML(String fileURL) {

		// init
		Document xmldoc = new Document();
		Element root = new Element("ID3");
		xmldoc.setRootElement(root);
		root.setAttribute("accuracy", "" + selfAccuracy());

		// write tree

		// write decision tree to xml file.
		// NOTE(review): FileWriter encodes with the platform default charset while
		// the XML declaration says gb2312; harmless for the ASCII-only content
		// written here, but confirm before writing non-ASCII text.
		try (FileWriter writer = new FileWriter(fileURL)) {
			XMLOutputter outputter = new XMLOutputter(Format.getPrettyFormat().setEncoding("gb2312"));

			outputter.output(xmldoc, writer); // Output to the file.
			outputter.output(xmldoc, System.out); // Output to the console.
		} catch (IOException e) {
			System.out.println(e);
		} // Of try
	}// writeToXML

	/**
	 * Machine learning demo: builds a tree from D:/weather.arff, prints the
	 * decision paths and the training accuracy, and writes D:/dt.xml.
	 */
	public void learning() {
		ID3 demo = new ID3("D:/weather.arff");
		demo.buildDecisionTree();

		System.out.println("The decision path is:\n\n" + demo);
		System.out.println("The accuracy is: " + demo.selfAccuracy() * 100 + "%");

		System.out.println("\nThe tree is:\n");
		demo.writeToXML("D:/dt.xml");
	}// learning

	/**
	 * Placeholder for predicting an unlabeled instance. Not implemented yet.
	 * 
	 * eg: Sunny,Hot,High,FALSE,?
	 */
	public void predict() {

	}

	/**
	 * Entry point: runs the learning demo.
	 * 
	 * @param args not used.
	 */
	public static void main(String[] args) {
		ID3 id3 = new ID3();
		id3.learning();
	}// Of main

}// ID3

1.3 输出样例

The decision path is:

Outlook = Sunny:Humidity = High:class = 0
Humidity = Normal:class = 1
Humidity = Low:class = 0

Outlook = Overcast:class = 1
Outlook = Rain:Windy = FALSE:class = 1
Windy = TRUE:class = 0


The accuracy is: 100.0%

The tree is:

<?xml version="1.0" encoding="gb2312"?>
<ID3 accuracy="1.0" />

2. 未完,待续

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐