1 OpenCV 机器学习算法

算法 解释
1、Normal Bayes Classifier (贝叶斯分类) 原理:基于贝叶斯定理计算类别的条件概率,预测样本属于哪个类别,对于给定的输入特征向量X,计算每个类别Y的后验概率P(Y
2、K-Nearest Neighbour Classifier (K-近邻算法) 原理:根据输入样本与训练样本的距离,预测输入样本的类别,对于给定的输入样本,找到K个最近邻训练样本,然后根据这K个样本的类别信息,采用投票或概率加权的方式预测输入样本的类别。
优点:算法简单,容易理解和实现。对异常值和噪声数据也有一定鲁棒性。
缺点:计算量大,需要存储所有训练数据,当训练数据量大时,预测速度会变慢。
适用场景:图像分类、语音识别等。
3、SVM (支持向量机) 原理:寻找超平面,使得不同类别的样本间隔最大,对于给定的训练数据,找到一个超平面w^T*x+b=0,使得同类样本落在超平面的同一侧,异类样本落在超平面的不同侧,且间隔最大。优点:泛化性能好,能很好地处理高维数据。
缺点:对参数调优比较敏感,计算量大。
适用场景:图像分类、文本分类、生物信息等。
4、EM (期望最大化算法) 原理:通过迭代的方式估计概率模型的参数,适用于存在隐变量的情况,对于含有隐变量的概率模型,先猜测初始参数,然后交替进行两步: E步:计算隐变量的期望值。M步:用期望值最大化参数。直到参数收敛为止,最终得到模型的最大似然估计。
优点:可以处理缺失数据,对初始值不太敏感。
缺点:收敛速度可能较慢,容易陷入局部最优。
适用场景:聚类分析、主题模型等。
5、Decision Tree (决策树) 原理:通过递归的方式构建分类或回归模型,形成树状结构,从根节点开始,根据某个特征的取值进行二分或多分,直到叶节点,每个叶节点表示一个类别。
优点:模型简单易懂,对异常值和缺失数据具有一定的鲁棒性。
缺点:容易过拟合,泛化性能可能较差。
适用场景:医疗诊断、客户信用评估等。
6、Random Trees Classifier (随机森林) 原理:集成多棵决策树,通过随机选取特征和样本的方式训练,综合多棵树的预测结果,构建多棵不相关的决策树,在训练每棵树时,随机选择部分特征和部分样本。
优点:准确性高,对噪声和缺失数据具有较强的鲁棒性。
缺点:模型复杂度高,解释性较差。
适用场景:图像分类、生物信息等。
7、Boosted Tree Classifier (Boost 树算法) 原理:通过迭代的方式训练多棵弱决策树,最终将它们组合成一个强分类器,在每一轮训练中,根据上一轮的预测误差,训练一棵新的弱决策树,并给它分配一个权重。
优点:准确性高,对异常值和噪声数据具有一定鲁棒性。
缺点:对异常值和噪声数据敏感,容易过拟合。
适用场景:图像分类、广告点击预测等。
8、Stochastic Gradient Descent SVM Classifier (SGD SVM) 原理:使用随机梯度下降法优化 SVM 的目标函数,对于标准SVM的凸二次规划问题,采用随机梯度下降法进行优化。
优点:训练速度快,可以处理大规模数据。
缺点:对参数调优比较敏感,性能可能略低于传统 SVM。
适用场景:大规模数据分类任务。
9、ANN (人工神经网络) 原理:通过模拟生物神经网络的结构和工作机制来构建机器学习模型,将输入特征通过一系列的神经元和权重连接,经过非线性变换,最终得到输出。
优点:对复杂非线性问题有较强的拟合能力,可以自动学习特征。
缺点:训练过程复杂,容易过拟合,需要大量训练数据。
适用场景:图像识别、语音识别、自然语言处理等。

2 OpenCV SVM 手写数字识别训练

手写数字

package com.xu.opencv.ml;

import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import org.opencv.core.TermCriteria;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.ml.Ml;
import org.opencv.ml.SVM;

import java.io.File;
import java.util.LinkedList;
import java.util.List;

/**
 * SVM 测试
 *
 * @author hyacinth
 * @date 2024年4月6日22点36分
 * @since V1.0.0.0
 */
public class SvmTest {

    /**
     * 训练图片
     */
    private static final String IMG = "lib/opencv/img/1.png";

    /**
     * 训练结果文件
     */
    private static final String FILE = "lib/opencv/file/num.xml";

    static {
        String os = System.getProperty("os.name");
        String type = System.getProperty("sun.arch.data.model");
        if (os.toUpperCase().contains("WINDOWS")) {
            File lib;
            if (type.endsWith("64")) {
                lib = new File("lib\\opencv\\x64\\" + System.mapLibraryName("opencv_java490"));
            } else {
                lib = new File("lib\\opencv\\x86\\" + System.mapLibraryName("opencv_java490"));
            }
            System.load(lib.getAbsolutePath());
        }
    }

    public static void main(String[] args) {
        //train(IMG, FILE);
        predict(IMG, FILE);
    }

    /**
     * 预测
     *
     * @param image 训练图片
     * @param file  训练文件
     * @date 2024年4月6日22点36分
     * @since V1.0.0.0
     */
    private static void predict(String image, String file) {
        // 加载训练文件
        SVM svm = SVM.load(file);

        // 预测图片
        Mat src = Imgcodecs.imread(image, Imgcodecs.IMREAD_GRAYSCALE);
        Rect rect = new Rect(0, 700, 20, 20);
        Mat dst = new Mat(src, rect);

        // 预测
        Mat img = new Mat(1, (int) dst.total() * dst.channels(), CvType.CV_32F);
        img.put(0, 0, pixel(dst));

        // 预测
        Mat result = new Mat();
        svm.predict(img, result);

        // 输出预测结果
        System.out.println("预测结果: " + (int) result.get(0, 0)[0]);
    }

    /**
     * 训练
     *
     * @param image 训练图片
     * @param file  训练文件
     * @date 2024年4月6日22点36分
     * @since V1.0.0.0
     */
    private static void train(String image, String file) {
        // 准备训练数据
        List<Mat> train = new LinkedList<>();
        List<Integer> label = new LinkedList<>();

        // 读取训练数据
        load(image, train, label);

        // 创建 SVM 分类器
        org.opencv.ml.SVM svm = org.opencv.ml.SVM.create();
        svm.setC(1);
        svm.setP(0);
        svm.setNu(0);
        svm.setCoef0(0);
        svm.setGamma(1);
        svm.setDegree(0);
        svm.setType(org.opencv.ml.SVM.C_SVC);
        svm.setKernel(org.opencv.ml.SVM.LINEAR);
        svm.setTermCriteria(new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 1000, 0));

        // 训练 SVM 模型
        Mat samples = new Mat(train.size(), (int) train.get(0).total(), CvType.CV_32FC1);
        Mat responses = new Mat(label.size(), 1, CvType.CV_32SC1);
        for (int i = 0; i < train.size(); i++) {
            samples.put(i, 0, pixel(train.get(i)));
            responses.put(i, 0, label.get(i));
        }
        svm.train(samples, Ml.ROW_SAMPLE, responses);

        // 保存模型
        svm.save(file);
    }

    /**
     * 图片转数组
     *
     * @param img 图片
     * @return 图像数组
     * @date 2024年4月6日22点36分
     * @since V1.0.0.0
     */
    private static float[] pixel(Mat img) {
        Mat mat = new Mat();
        img.convertTo(mat, CvType.CV_32F);

        int count = 0;
        float[] pixel = new float[(int) (mat.total() * mat.channels())];
        for (int i = 0, row = (int) mat.size().height; i < row; i++) {
            for (int j = 0, col = (int) mat.size().width; j < col; j++) {
                pixel[count] = (float) mat.get(i, j)[0];
                count++;
            }
        }

        return pixel;
    }

    /**
     * 加载文件
     *
     * @param image 训练图片
     * @param train 训练文件
     * @param label 训练标签
     * @date 2024年4月6日22点36分
     * @since V1.0.0.0
     */
    public static void load(String image, List<Mat> train, List<Integer> label) {
        Mat src = Imgcodecs.imread(image, Imgcodecs.IMREAD_GRAYSCALE);
        for (int i = 0; i <= 49; i++) {
            for (int j = 0; j <= 99; j++) {
                label.add((int) Math.floor(i / 5));
                Rect rect = new Rect(j * 20, i * 20, 20, 20);
                train.add(new Mat(src, rect));
            }
        }
    }

}

3 OpenCV SVM 手写数字识别预测

D:\Environment\Java\jdk-21\bin\java.exe "-javaagent:D:\IDE\IntelliJ IDEA Community Edition 2023.1.1\lib\idea_rt.jar=11182:D:\IDE\IntelliJ IDEA Community Edition 2023.1.1\bin" -Dfile.encoding=UTF-8 -Dsun.stdout.encoding=UTF-8 -Dsun.stderr.encoding=UTF-8 -classpath D:\SourceCode\Intellij\OpenCV\target\classes;D:\SourceCode\Intellij\OpenCV\lib\opencv\opencv-490.jar com.xu.opencv.ml.SvmTest
预测结果: 7

Process finished with exit code 0

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐