效果演示

训练模型

运行train.py,第一次运行需要注册swanlab账号,申请apikey
在这里插入图片描述
进入可视化页面后可以看到当前训练的一些详细的参数
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

文件 用途 性能可靠性 适用场景
best.pt 部署、测试 高(验证集最优) 实际应用、模型评估
last.pt 恢复训练、调试 可能不稳定 继续训练、检查训练收敛性

运行模型

在这里插入图片描述
访问网址
在这里插入图片描述
在这里插入图片描述

环境准备

依赖

名称 版本
python 3.8
torch cpu1.8.2
torchaudio cpu0.8.2
torchvision cpu0.9.2

在这里插入图片描述

目录结构

准备好图片和标注文件

(learn_model) Fruit_Model> tree
卷 worker 的文件夹 PATH 列表
卷序列号为 AACB-A67E
D:.
├─datasets
│  ├─images
│  │  ├─test
│  │  ├─train
│  │  └─val
│  └─labels
│      ├─test
│      ├─train
│      └─val

在这里插入图片描述

定义数据集的结构和类别信息

D:\Environment\python_home\CV机器视觉\Fruit_Model\datasets\data.yaml

path: D:\Environment\python_home\CV机器视觉\Fruit_Model\datasets\
train: images/train
test: images/test
val: images/val

# class names
names:
  0: apple
  1: banana
  2: not apple
  3: not banana
  4: orange

训练模型

train.py

from ultralytics import YOLO
from swanlab.integration.ultralytics import add_swanlab_callback
import swanlab

def main():
    swanlab.init(project="Fruit_model", experiment_name="YOLOv8n_optimized_fruit",)
    model = YOLO("yolov8n.pt")
    add_swanlab_callback(model)
    # 将下面的路径替换成你的绝对路径
    model.train(data="D:\\Environment\\python_home\\CV机器视觉\\Fruit_Model\\datasets\\data.yaml", epochs=5, imgsz=416, batch=16)

if __name__ == "__main__":
    main()

启动模型APP

config.py

import os
import pymysql

# 项目根目录
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# 模型配置
MODEL_PATH = os.path.join(BASE_DIR, 'yolo11n.pt')

# 上传文件配置
UPLOAD_DIR = os.path.join(BASE_DIR, 'uploads')
os.makedirs(UPLOAD_DIR, exist_ok=True)

# 数据库配置
DB_CONFIG = {
    "host": "192.168.100.153",
    "user": "root",
    "password": "123456",
    "database": "fruit_db",
    "charset": "utf8mb4",
    "cursorclass": pymysql.cursors.DictCursor,
    "connect_timeout": 10  # 连接超时时间(秒)
}

# 类别映射表
CLASS_NAME_MAPPING = {
    "apple": "苹果",
    "banana": "香蕉",
    "not apple": "不是苹果",
    "not banana": "不是香蕉",
    "orange": "橘子"
}

# 模型默认参数
DEFAULT_CONF_THRESHOLD = 0.25
DEFAULT_IOU_THRESHOLD = 0.45
DEFAULT_IMGSZ = 416

# Gradio 服务器配置
GRADIO_SERVER_CONFIG = {
    "share": True,
    "server_name": "127.0.0.1",
    "server_port": 7868
}

main.py

import gradio as gr
from ultralytics import YOLO
from PIL import Image
import pymysql
from datetime import datetime
import os
import time
from config import (
    MODEL_PATH, UPLOAD_DIR, DB_CONFIG, CLASS_NAME_MAPPING,
    DEFAULT_CONF_THRESHOLD, DEFAULT_IOU_THRESHOLD, DEFAULT_IMGSZ,
    GRADIO_SERVER_CONFIG
)


# 重试装饰器
def retry(tries=3, delay=1):
    def decorator(func):
        def wrapper(*args, **kwargs):
            attempt = 0
            while attempt < tries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    attempt += 1
                    if attempt == tries:
                        raise e
                    time.sleep(delay)

        return wrapper

    return decorator


# 健壮的数据库连接
@retry(tries=3, delay=2)
def get_db_connection():
    try:
        connection = pymysql.connect(**DB_CONFIG)
        return connection
    except Exception as e:
        print(f"数据库连接失败: {e}")
        raise


def save_to_database(image_path, detections):
    """将检测结果保存到数据库"""
    try:
        with get_db_connection() as connection:
            with connection.cursor() as cursor:
                # 插入检测记录
                insert_detection_query = """
                INSERT INTO detections (image_path, upload_time) 
                VALUES (%s, %s)
                """
                upload_time = datetime.now()
                cursor.execute(insert_detection_query, (image_path, upload_time))
                detection_id = cursor.lastrowid

                # 批量插入检测结果
                insert_result_query = """
                INSERT INTO detection_results 
                (detection_id, english_name, chinese_name, confidence) 
                VALUES (%s, %s, %s, %s)
                """
                result_data = [
                    (detection_id, english_name, chinese_name, confidence)
                    for english_name, chinese_name, confidence in detections
                ]
                cursor.executemany(insert_result_query, result_data)

                connection.commit()
                print(f"成功保存 {len(detections)} 条检测结果到数据库")

    except Exception as e:
        print(f"数据库操作失败: {e}")
        # 可以添加日志记录或其他恢复机制


def validate_image(image):
    """验证图像有效性"""
    if image is None:
        raise ValueError("请上传有效图像")

    # 检查图像尺寸
    width, height = image.size
    if width < 10 or height < 10:
        raise ValueError("图像尺寸过小")

    # 检查图像模式
    if image.mode not in ('RGB', 'RGBA'):
        try:
            image = image.convert('RGB')
        except Exception as e:
            raise ValueError("无法转换图像格式") from e

    return image


def predict_image(image, conf_threshold=DEFAULT_CONF_THRESHOLD, iou_threshold=DEFAULT_IOU_THRESHOLD):
    try:
        # 验证输入
        image = validate_image(image)

        # 保存原始图像
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        temp_image_path = os.path.join(UPLOAD_DIR, f"{timestamp}.jpg")

        # 确保文件名不冲突
        counter = 1
        while os.path.exists(temp_image_path):
            temp_image_path = os.path.join(UPLOAD_DIR, f"{timestamp}_{counter}.jpg")
            counter += 1

        image.save(temp_image_path)

        # 加载模型
        if not hasattr(predict_image, 'model'):
            if not os.path.exists(MODEL_PATH):
                raise FileNotFoundError(f"模型文件不存在: {MODEL_PATH}")
            predict_image.model = YOLO(MODEL_PATH)

        # 模型推理
        results = predict_image.model.predict(
            source=image,
            conf=conf_threshold,
            iou=iou_threshold,
            imgsz=DEFAULT_IMGSZ
        )

        # 处理检测结果
        detections = []
        class_results = []

        for r in results:
            # 绘制结果图像
            im_array = r.plot()
            result_image = Image.fromarray(im_array[..., ::-1])

            # 提取检测信息
            if r.boxes is not None and len(r.boxes) > 0:
                for box in r.boxes:
                    class_id = int(box.cls)
                    confidence = float(box.conf)

                    # 验证类别ID有效性
                    if class_id >= len(predict_image.model.names):
                        continue

                    english_name = predict_image.model.names[class_id]
                    chinese_name = CLASS_NAME_MAPPING.get(english_name, english_name)

                    detections.append((english_name, chinese_name, confidence))
                    class_results.append(f"{chinese_name} ({confidence:.2f})")

        # 保存到数据库
        if detections:
            save_to_database(temp_image_path, detections)

        # 格式化结果
        unique_results = ", ".join(sorted(set(class_results), key=class_results.index)) if class_results else "未检测到目标"
        print(f"检测结果: {unique_results}")

        return result_image, unique_results

    except FileNotFoundError as e:
        print(f"文件错误: {e}")
        return None, f"文件错误: {str(e)}"
    except ValueError as e:
        print(f"输入验证失败: {e}")
        return None, f"输入错误: {str(e)}"
    except Exception as e:
        print(f"处理图片时出错: {e}")
        return None, f"处理图片时出错: {str(e)}"


# 创建Gradio界面
demo = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type="pil", label="上传图片"),
        gr.Slider(0, 1, DEFAULT_CONF_THRESHOLD, label="置信度阈值"),
        gr.Slider(0, 1, DEFAULT_IOU_THRESHOLD, label="交并比阈值"),
    ],
    outputs=[
        gr.Image(type="pil", label="识别结果"),
        gr.Textbox(label="检测到的水果分类及置信度")
    ],
    title="人工智能识别系统",
    description="请上传一张水果的图像进行识别",
    examples=[
        ["examples/apple.jpg", 0.3, 0.45],
        ["examples/banana.jpg", 0.25, 0.5]
    ] if os.path.exists("examples") else None
)

# 启动应用
if __name__ == "__main__":
    print(
        f"启动Gradio应用,访问地址: http://{GRADIO_SERVER_CONFIG['server_name']}:{GRADIO_SERVER_CONFIG['server_port']}")
    demo.launch(**GRADIO_SERVER_CONFIG)
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐