【Project】YOLO分类模型训练
YOLO初学习
·
效果演示
训练模型
运行 train.py。首次运行时需要注册 SwanLab 账号并申请 API Key。
进入 SwanLab 可视化页面后,可以看到本次训练的详细参数。
文件 | 用途 | 性能可靠性 | 适用场景 |
---|---|---|---|
best.pt | 部署、测试 | 高(验证集最优) | 实际应用、模型评估 |
last.pt | 恢复训练、调试 | 可能不稳定 | 继续训练、检查训练收敛性 |
运行模型
访问网址
环境准备
依赖
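名称 | 版本 |
---|---|
python | 3.8 |
torch | 1.8.2(CPU 版) |
torchaudio | 0.8.2(CPU 版) |
torchvision | 0.9.2(CPU 版) |
ultralytics | 未注明(训练与推理,train.py / main.py 使用) |
swanlab | 未注明(训练可视化,train.py 使用) |
gradio | 未注明(Web 界面,main.py 使用) |
pymysql | 未注明(MySQL 存储,config.py / main.py 使用) |
pillow | 未注明(图像处理,main.py 使用) |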
目录结构
准备好图片和标注文件
(learn_model) Fruit_Model> tree
卷 worker 的文件夹 PATH 列表
卷序列号为 AACB-A67E
D:.
├─datasets
│ ├─images
│ │ ├─test
│ │ ├─train
│ │ └─val
│ └─labels
│ ├─test
│ ├─train
│ └─val
定义数据集的结构和类别信息
D:\Environment\python_home\CV机器视觉\Fruit_Model\datasets\data.yaml
# Dataset root directory; the train/test/val entries below are relative to it.
path: D:\Environment\python_home\CV机器视觉\Fruit_Model\datasets\
train: images/train
test: images/test
val: images/val
# Class names: class index -> name. The entries must be indented under
# `names:`; otherwise YAML parses `names` as null and the ids as top-level keys.
names:
  0: apple
  1: banana
  2: not apple
  3: not banana
  4: orange
训练模型
train.py
from ultralytics import YOLO
from swanlab.integration.ultralytics import add_swanlab_callback
import swanlab
def main(data="D:\\Environment\\python_home\\CV机器视觉\\Fruit_Model\\datasets\\data.yaml",
         epochs=5, imgsz=416, batch=16):
    """Train a YOLOv8n model on the fruit dataset and log the run to SwanLab.

    Args:
        data: Path to the dataset YAML. The default is the author's absolute
            Windows path; replace it (or pass your own path) on another machine.
        epochs: Number of training epochs.
        imgsz: Training image size passed to ``model.train``.
        batch: Batch size passed to ``model.train``.
    """
    swanlab.init(project="Fruit_model", experiment_name="YOLOv8n_optimized_fruit")
    # Start from the pretrained yolov8n checkpoint.
    model = YOLO("yolov8n.pt")
    # Hook SwanLab's Ultralytics integration into this model so training is logged.
    add_swanlab_callback(model)
    model.train(data=data, epochs=epochs, imgsz=imgsz, batch=batch)


if __name__ == "__main__":
    main()
启动模型APP
config.py
import os
import pymysql
# Project root: the directory that contains this config file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Model weights loaded by the web app; override with the FRUIT_MODEL_PATH env var.
# NOTE(review): the default is the stock "yolo11n.pt" checkpoint, whereas
# train.py fine-tunes "yolov8n.pt" on the five fruit classes. To serve the
# trained classes, point FRUIT_MODEL_PATH at that run's best.pt
# (Ultralytics usually writes it under runs/detect/<run>/weights/ — confirm).
MODEL_PATH = os.environ.get("FRUIT_MODEL_PATH", os.path.join(BASE_DIR, 'yolo11n.pt'))

# Directory for saved uploads; created at import time if missing.
UPLOAD_DIR = os.path.join(BASE_DIR, 'uploads')
os.makedirs(UPLOAD_DIR, exist_ok=True)

# MySQL settings, used as pymysql.connect(**DB_CONFIG).
# Credentials may be supplied via environment variables so they need not be
# committed; the literal values are the original defaults.
DB_CONFIG = {
    "host": os.environ.get("FRUIT_DB_HOST", "192.168.100.153"),
    "user": os.environ.get("FRUIT_DB_USER", "root"),
    "password": os.environ.get("FRUIT_DB_PASSWORD", "123456"),
    "database": os.environ.get("FRUIT_DB_NAME", "fruit_db"),
    "charset": "utf8mb4",
    # Cursors return each row as a dict keyed by column name.
    "cursorclass": pymysql.cursors.DictCursor,
    "connect_timeout": 10,  # connection timeout in seconds
}

# English class names (as listed in data.yaml) -> Chinese display names.
CLASS_NAME_MAPPING = {
    "apple": "苹果",
    "banana": "香蕉",
    "not apple": "不是苹果",
    "not banana": "不是香蕉",
    "orange": "橘子"
}

# Default inference settings: confidence threshold, NMS IoU threshold, input size.
DEFAULT_CONF_THRESHOLD = 0.25
DEFAULT_IOU_THRESHOLD = 0.45
DEFAULT_IMGSZ = 416

# Keyword arguments for the Gradio launch() call.
# NOTE: share=True also asks Gradio to create a public share link, so the app
# (and the database writes behind it) is reachable beyond 127.0.0.1.
GRADIO_SERVER_CONFIG = {
    "share": True,
    "server_name": "127.0.0.1",
    "server_port": 7868
}
main.py
import gradio as gr
from ultralytics import YOLO
from PIL import Image
import pymysql
from datetime import datetime
import os
import time
from config import (
MODEL_PATH, UPLOAD_DIR, DB_CONFIG, CLASS_NAME_MAPPING,
DEFAULT_CONF_THRESHOLD, DEFAULT_IOU_THRESHOLD, DEFAULT_IMGSZ,
GRADIO_SERVER_CONFIG
)
# Retry decorator.
def retry(tries=3, delay=1):
    """Build a decorator that retries the wrapped call when it raises.

    Args:
        tries: Total number of attempts. Values below 1 are treated as 1, so
            the wrapped function is always called at least once (previously
            ``tries<=0`` returned ``None`` without calling it).
        delay: Seconds to sleep between a failed attempt and the next one.

    Returns:
        A decorator. The wrapped function returns the first successful result.
        If every attempt raises an ``Exception``, the last one is re-raised with
        its original traceback.
    """
    from functools import wraps

    attempts = max(1, int(tries))

    def decorator(func):
        # wraps() keeps the wrapped function's __name__/__doc__.
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    if attempt == attempts:
                        raise
                    time.sleep(delay)
        return wrapper
    return decorator
# Database connection with automatic retries.
@retry(tries=3, delay=2)
def get_db_connection():
    """Open and return a new PyMySQL connection built from ``DB_CONFIG``.

    Every failed attempt is printed and re-raised. The ``retry`` decorator
    then makes up to 3 attempts in total, pausing 2 s between them, and
    re-raises the last error.
    """
    try:
        conn = pymysql.connect(**DB_CONFIG)
    except Exception as err:
        print(f"数据库连接失败: {err}")
        raise
    return conn
def save_to_database(image_path, detections):
    """Persist one detection run and its per-object results to MySQL.

    Inserts one row into ``detections`` (image path and current time). It then
    inserts one row per item into ``detection_results``, keyed by the new
    row's id, and commits both in a single transaction.

    Args:
        image_path: Path of the saved upload, stored in ``detections.image_path``.
        detections: Iterable of ``(english_name, chinese_name, confidence)`` tuples.

    This is best effort: any error is printed, not raised, so a failed save
    does not break the prediction response. On failure the open transaction
    is rolled back before the connection is closed.
    """
    from contextlib import closing

    detections = list(detections)
    insert_detection_query = """
    INSERT INTO detections (image_path, upload_time)
    VALUES (%s, %s)
    """
    insert_result_query = """
    INSERT INTO detection_results
    (detection_id, english_name, chinese_name, confidence)
    VALUES (%s, %s, %s, %s)
    """
    try:
        # closing() always closes the connection and does not depend on
        # Connection.__enter__, whose return value differs between PyMySQL
        # releases (0.x returned a cursor; 1.x returns the connection).
        with closing(get_db_connection()) as connection:
            try:
                with connection.cursor() as cursor:
                    upload_time = datetime.now()
                    cursor.execute(insert_detection_query, (image_path, upload_time))
                    # Auto-increment id of the row just inserted; links the results to it.
                    detection_id = cursor.lastrowid
                    result_data = [
                        (detection_id, english_name, chinese_name, confidence)
                        for english_name, chinese_name, confidence in detections
                    ]
                    cursor.executemany(insert_result_query, result_data)
                connection.commit()
            except Exception:
                # Do not leave a half-written detection behind.
                connection.rollback()
                raise
            print(f"成功保存 {len(detections)} 条检测结果到数据库")
    except Exception as e:
        print(f"数据库操作失败: {e}")
def validate_image(image):
    """Check an uploaded image and normalise it to RGB mode.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was uploaded.

    Returns:
        The image in RGB mode. If it was already RGB, the same object is returned.

    Raises:
        ValueError: if no image was given, if either side is smaller than 10 px,
            or if the image cannot be converted to RGB.
    """
    if image is None:
        raise ValueError("请上传有效图像")
    width, height = image.size
    if width < 10 or height < 10:
        raise ValueError("图像尺寸过小")
    # RGBA (e.g. a PNG with transparency) used to pass through unchanged, but
    # the caller saves the upload as .jpg and JPEG cannot store an alpha channel.
    if image.mode != 'RGB':
        try:
            image = image.convert('RGB')
        except Exception as e:
            raise ValueError("无法转换图像格式") from e
    return image
def _save_upload(image):
    """Save ``image`` under UPLOAD_DIR as "<YYYYmmddHHMMSS>[_<n>].jpg" and return the path."""
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    temp_image_path = os.path.join(UPLOAD_DIR, f"{timestamp}.jpg")
    # Uploads within the same second get a numeric suffix instead of overwriting.
    counter = 1
    while os.path.exists(temp_image_path):
        temp_image_path = os.path.join(UPLOAD_DIR, f"{timestamp}_{counter}.jpg")
        counter += 1
    # JPEG has no alpha channel, so non-RGB images (e.g. RGBA) are converted first.
    to_save = image if image.mode == "RGB" else image.convert("RGB")
    to_save.save(temp_image_path)
    return temp_image_path


def _get_model():
    """Return the YOLO model, loading it from MODEL_PATH on first use.

    The instance is cached on ``predict_image.model``, so it is loaded once per process.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    if not hasattr(predict_image, 'model'):
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"模型文件不存在: {MODEL_PATH}")
        predict_image.model = YOLO(MODEL_PATH)
    return predict_image.model


def _collect_detections(result, names):
    """Return ``(english_name, chinese_name, confidence)`` tuples for one result's boxes.

    Boxes whose class id falls outside ``names`` are skipped.
    """
    found = []
    if result.boxes is None or len(result.boxes) == 0:
        return found
    for box in result.boxes:
        class_id = int(box.cls)
        confidence = float(box.conf)
        if not 0 <= class_id < len(names):
            continue
        english_name = names[class_id]
        chinese_name = CLASS_NAME_MAPPING.get(english_name, english_name)
        found.append((english_name, chinese_name, confidence))
    return found


def predict_image(image, conf_threshold=DEFAULT_CONF_THRESHOLD, iou_threshold=DEFAULT_IOU_THRESHOLD):
    """Run the YOLO model on one uploaded image, save the upload, and record detections.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input.
        conf_threshold: Passed to ``model.predict(conf=...)``.
        iou_threshold: Passed to ``model.predict(iou=...)``.

    Returns:
        A ``(result_image, text)`` tuple.
        - On success, ``result_image`` is the annotated image. ``text`` lists
          the unique "<name> (<conf>)" labels in first-seen order, or
          "未检测到目标" when nothing was found.
        - On failure, ``result_image`` is ``None`` and ``text`` is an error message.
    """
    try:
        image = validate_image(image)
        temp_image_path = _save_upload(image)
        model = _get_model()
        results = model.predict(
            source=image,
            conf=conf_threshold,
            iou=iou_threshold,
            imgsz=DEFAULT_IMGSZ
        )
        # Default to the input image so an empty result list still returns an
        # image instead of failing on an unassigned variable.
        result_image = image
        detections = []
        for r in results:
            # plot() draws the boxes. Its array is channel-reversed before going
            # to PIL (Ultralytics documents plot() output as BGR).
            result_image = Image.fromarray(r.plot()[..., ::-1])
            detections.extend(_collect_detections(r, model.names))
        if detections:
            save_to_database(temp_image_path, detections)
        labels = [f"{chinese_name} ({confidence:.2f})" for _, chinese_name, confidence in detections]
        # dict.fromkeys de-duplicates while preserving first-seen order, in O(n).
        unique_results = ", ".join(dict.fromkeys(labels)) if labels else "未检测到目标"
        print(f"检测结果: {unique_results}")
        return result_image, unique_results
    except FileNotFoundError as e:
        print(f"文件错误: {e}")
        return None, f"文件错误: {str(e)}"
    except ValueError as e:
        print(f"输入验证失败: {e}")
        return None, f"输入错误: {str(e)}"
    except Exception as e:
        print(f"处理图片时出错: {e}")
        return None, f"处理图片时出错: {str(e)}"
# Build the Gradio UI: an image and two threshold sliders in; an annotated
# image and a text summary out.
# Example images are looked up next to this script rather than in the current
# working directory, and only files that actually exist are offered to Gradio.
_EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
_EXAMPLES = [
    [os.path.join(_EXAMPLES_DIR, file_name), conf, iou]
    for file_name, conf, iou in (("apple.jpg", 0.3, 0.45), ("banana.jpg", 0.25, 0.5))
    if os.path.isfile(os.path.join(_EXAMPLES_DIR, file_name))
]

demo = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(type="pil", label="上传图片"),
        gr.Slider(minimum=0, maximum=1, value=DEFAULT_CONF_THRESHOLD, label="置信度阈值"),
        gr.Slider(minimum=0, maximum=1, value=DEFAULT_IOU_THRESHOLD, label="交并比阈值"),
    ],
    outputs=[
        gr.Image(type="pil", label="识别结果"),
        gr.Textbox(label="检测到的水果分类及置信度")
    ],
    title="人工智能识别系统",
    description="请上传一张水果的图像进行识别",
    examples=_EXAMPLES or None,
)
# Entry point: print the configured local address, then start the Gradio server.
if __name__ == "__main__":
    host = GRADIO_SERVER_CONFIG['server_name']
    port = GRADIO_SERVER_CONFIG['server_port']
    print(f"启动Gradio应用,访问地址: http://{host}:{port}")
    demo.launch(**GRADIO_SERVER_CONFIG)
更多推荐
所有评论(0)