Java OpenCV 人工智能01.2 机器学习 支持向量机 SVM
Java OpenCV-4.X 人工智能01.2 机器学习 支持向量机 SVM1 Python TensorFlow 2.6 获取 MNIST 数据1.1 获取 MNIST 数据1.2 检查 MNIST 数据2 Python 将npz数据保存为txt3 Java 获取数据并使用SVM训练4 Python 测试SVM准确度1 Python TensorFlow 2.6 获取 MNIST 数据1.1
·
Java OpenCV 人工智能01.2 机器学习 支持向量机 SVM
1 OpenCV 机器学习算法
算法 | 解释 |
---|---|
1、Normal Bayes Classifier (贝叶斯分类) | 原理:基于贝叶斯定理计算类别的条件概率,预测样本属于哪个类别,对于给定的输入特征向量X,计算每个类别Y的后验概率P(Y |
2、K-Nearest Neighbour Classifier (K-近邻算法) | 原理:根据输入样本与训练样本的距离,预测输入样本的类别,对于给定的输入样本,找到K个最近邻训练样本,然后根据这K个样本的类别信息,采用投票或概率加权的方式预测输入样本的类别。 优点:算法简单,容易理解和实现。对异常值和噪声数据也有一定鲁棒性。 缺点:计算量大,需要存储所有训练数据,当训练数据量大时,预测速度会变慢。 适用场景:图像分类、语音识别等。 |
3、SVM (支持向量机) | 原理:寻找超平面,使得不同类别的样本间隔最大,对于给定的训练数据,找到一个超平面w^T*x+b=0,使得同类样本落在超平面的同一侧,异类样本落在超平面的不同侧,且间隔最大。优点:泛化性能好,能很好地处理高维数据。 缺点:对参数调优比较敏感,计算量大。 适用场景:图像分类、文本分类、生物信息等。 |
4、EM (期望最大化算法) | 原理:通过迭代的方式估计概率模型的参数,适用于存在隐变量的情况,对于含有隐变量的概率模型,先猜测初始参数,然后交替进行两步: E步:计算隐变量的期望值。M步:用期望值最大化参数。直到参数收敛为止,最终得到模型的最大似然估计。 优点:可以处理缺失数据,对初始值不太敏感。 缺点:收敛速度可能较慢,容易陷入局部最优。 适用场景:聚类分析、主题模型等。 |
5、Decision Tree (决策树) | 原理:通过递归的方式构建分类或回归模型,形成树状结构,从根节点开始,根据某个特征的取值进行二分或多分,直到叶节点,每个叶节点表示一个类别。 优点:模型简单易懂,对异常值和缺失数据具有一定的鲁棒性。 缺点:容易过拟合,泛化性能可能较差。 适用场景:医疗诊断、客户信用评估等。 |
6、Random Trees Classifier (随机森林) | 原理:集成多棵决策树,通过随机选取特征和样本的方式训练,综合多棵树的预测结果,构建多棵不相关的决策树,在训练每棵树时,随机选择部分特征和部分样本。 优点:准确性高,对噪声和缺失数据具有较强的鲁棒性。 缺点:模型复杂度高,解释性较差。 适用场景:图像分类、生物信息等。 |
7、Boosted Tree Classifier (Boost 树算法) | 原理:通过迭代的方式训练多棵弱决策树,最终将它们组合成一个强分类器,在每一轮训练中,根据上一轮的预测误差,训练一棵新的弱决策树,并给它分配一个权重。 优点:准确性高,对异常值和噪声数据具有一定鲁棒性。 缺点:对异常值和噪声数据敏感,容易过拟合。 适用场景:图像分类、广告点击预测等。 |
8、Stochastic Gradient Descent SVM Classifier (SGD SVM) | 原理:使用随机梯度下降法优化 SVM 的目标函数,对于标准SVM的凸二次规划问题,采用随机梯度下降法进行优化。 优点:训练速度快,可以处理大规模数据。 缺点:对参数调优比较敏感,性能可能略低于传统 SVM。 适用场景:大规模数据分类任务。 |
9、ANN (人工神经网络) | 原理:通过模拟生物神经网络的结构和工作机制来构建机器学习模型,将输入特征通过一系列的神经元和权重连接,经过非线性变换,最终得到输出。 优点:对复杂非线性问题有较强的拟合能力,可以自动学习特征。 缺点:训练过程复杂,容易过拟合,需要大量训练数据。 适用场景:图像识别、语音识别、自然语言处理等。 |
2 OpenCV SVM 手写数字识别训练
package com.xu.opencv.ml;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import org.opencv.core.TermCriteria;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.ml.Ml;
import org.opencv.ml.SVM;
import java.io.File;
import java.util.LinkedList;
import java.util.List;
/**
* SVM 测试
*
* @author hyacinth
* @date 2024年4月6日22点36分
* @since V1.0.0.0
*/
public class SvmTest {
/**
* 训练图片
*/
private static final String IMG = "lib/opencv/img/1.png";
/**
* 训练结果文件
*/
private static final String FILE = "lib/opencv/file/num.xml";
static {
String os = System.getProperty("os.name");
String type = System.getProperty("sun.arch.data.model");
if (os.toUpperCase().contains("WINDOWS")) {
File lib;
if (type.endsWith("64")) {
lib = new File("lib\\opencv\\x64\\" + System.mapLibraryName("opencv_java490"));
} else {
lib = new File("lib\\opencv\\x86\\" + System.mapLibraryName("opencv_java490"));
}
System.load(lib.getAbsolutePath());
}
}
public static void main(String[] args) {
//train(IMG, FILE);
predict(IMG, FILE);
}
/**
* 预测
*
* @param image 训练图片
* @param file 训练文件
* @date 2024年4月6日22点36分
* @since V1.0.0.0
*/
private static void predict(String image, String file) {
// 加载训练文件
SVM svm = SVM.load(file);
// 预测图片
Mat src = Imgcodecs.imread(image, Imgcodecs.IMREAD_GRAYSCALE);
Rect rect = new Rect(0, 700, 20, 20);
Mat dst = new Mat(src, rect);
// 预测
Mat img = new Mat(1, (int) dst.total() * dst.channels(), CvType.CV_32F);
img.put(0, 0, pixel(dst));
// 预测
Mat result = new Mat();
svm.predict(img, result);
// 输出预测结果
System.out.println("预测结果: " + (int) result.get(0, 0)[0]);
}
/**
* 训练
*
* @param image 训练图片
* @param file 训练文件
* @date 2024年4月6日22点36分
* @since V1.0.0.0
*/
private static void train(String image, String file) {
// 准备训练数据
List<Mat> train = new LinkedList<>();
List<Integer> label = new LinkedList<>();
// 读取训练数据
load(image, train, label);
// 创建 SVM 分类器
org.opencv.ml.SVM svm = org.opencv.ml.SVM.create();
svm.setC(1);
svm.setP(0);
svm.setNu(0);
svm.setCoef0(0);
svm.setGamma(1);
svm.setDegree(0);
svm.setType(org.opencv.ml.SVM.C_SVC);
svm.setKernel(org.opencv.ml.SVM.LINEAR);
svm.setTermCriteria(new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 1000, 0));
// 训练 SVM 模型
Mat samples = new Mat(train.size(), (int) train.get(0).total(), CvType.CV_32FC1);
Mat responses = new Mat(label.size(), 1, CvType.CV_32SC1);
for (int i = 0; i < train.size(); i++) {
samples.put(i, 0, pixel(train.get(i)));
responses.put(i, 0, label.get(i));
}
svm.train(samples, Ml.ROW_SAMPLE, responses);
// 保存模型
svm.save(file);
}
/**
* 图片转数组
*
* @param img 图片
* @return 图像数组
* @date 2024年4月6日22点36分
* @since V1.0.0.0
*/
private static float[] pixel(Mat img) {
Mat mat = new Mat();
img.convertTo(mat, CvType.CV_32F);
int count = 0;
float[] pixel = new float[(int) (mat.total() * mat.channels())];
for (int i = 0, row = (int) mat.size().height; i < row; i++) {
for (int j = 0, col = (int) mat.size().width; j < col; j++) {
pixel[count] = (float) mat.get(i, j)[0];
count++;
}
}
return pixel;
}
/**
* 加载文件
*
* @param image 训练图片
* @param train 训练文件
* @param label 训练标签
* @date 2024年4月6日22点36分
* @since V1.0.0.0
*/
public static void load(String image, List<Mat> train, List<Integer> label) {
Mat src = Imgcodecs.imread(image, Imgcodecs.IMREAD_GRAYSCALE);
for (int i = 0; i <= 49; i++) {
for (int j = 0; j <= 99; j++) {
label.add((int) Math.floor(i / 5));
Rect rect = new Rect(j * 20, i * 20, 20, 20);
train.add(new Mat(src, rect));
}
}
}
}
3 OpenCV SVM 手写数字识别预测
D:\Environment\Java\jdk-21\bin\java.exe "-javaagent:D:\IDE\IntelliJ IDEA Community Edition 2023.1.1\lib\idea_rt.jar=11182:D:\IDE\IntelliJ IDEA Community Edition 2023.1.1\bin" -Dfile.encoding=UTF-8 -Dsun.stdout.encoding=UTF-8 -Dsun.stderr.encoding=UTF-8 -classpath D:\SourceCode\Intellij\OpenCV\target\classes;D:\SourceCode\Intellij\OpenCV\lib\opencv\opencv-490.jar com.xu.opencv.ml.SvmTest
预测结果: 7
Process finished with exit code 0
更多推荐
所有评论(0)