ResNet image classification / image recognition: exporting weights and converting a .pt model to ONNX

I recently ran into some trouble deploying a model on an edge-computing product: the device does not support YOLO classification models, so I retrained a ResNet network with PyTorch instead, exported the trained weights, and converted them to the universal ONNX format. After consulting code posted by others online, I put the pieces together and am recording them here for future reference.

The dataset directory structure is shown below (it is very simple): put the images for each class into the corresponding folder, splitting train and val roughly 8:2.
[screenshot: dataset directory tree]
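For reference, datasets.ImageFolder expects one sub-directory per class under train/ and val/, with the folder name acting as the class label. A minimal sketch of such a layout (the folder names class_a/class_b and the file names are illustrative examples, not taken from this post):

my_datasets/
├── train/              # roughly 80% of the images
│   ├── class_a/        # one folder per class; the folder name becomes the label
│   │   ├── 001.jpg
│   │   └── ...
│   └── class_b/
│       └── ...
└── val/                # roughly 20% of the images
    ├── class_a/
    └── class_b/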
The complete training code is below. It covers data loading, model training, weight saving, and format conversion:

import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
from torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings
 
warnings.filterwarnings("ignore")
 
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
n_classes = 5  # number of classes
pretrain = False  # whether to download and use pretrained weights (True if online, False if offline)
epochs = 20  # number of training epochs
traindataset = datasets.ImageFolder(root='./my_datasets/train/', transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
]))
 
testdataset = datasets.ImageFolder(root='./my_datasets/val/', transform=transforms.Compose([
    transforms.Resize(256),        # deterministic preprocessing for validation
    transforms.CenterCrop(224),    # (random crops/flips only make sense during training)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))
 
classes = testdataset.classes
#print(classes)
 
model = models.resnet18(pretrained=pretrain)
if pretrain:
    # transfer learning: freeze the backbone and only train the new classifier head
    for param in model.parameters():
        param.requires_grad = False
model.fc = nn.Linear(in_features=512, out_features=n_classes, bias=True)  # replace the final fc layer
model = model.to(device)
 
 
def train_model(model, train_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = outputs.argmax(dim=1)
        total_corrects += torch.sum(preds.eq(labels))
        total_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
    total_loss = total_loss / total
    acc = 100 * total_corrects / total
    print("echo:%4d, 损失loss:%.5f, 准确率:%6.2f%%" % (epoch + 1, total_loss, acc))
    return total_loss, acc
 
 
best_acc = -1.0  # best validation accuracy seen so far (decides when best.pt is overwritten)


def test_model(model, test_loader, loss_fn, optimizer, epoch):
    global best_acc
    model.eval()  # evaluation mode: disables dropout and freezes batch-norm statistics
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            total_loss += loss.item() * inputs.size(0)
            total_corrects += torch.sum(preds.eq(labels))
 
        loss = total_loss / total
        accuracy = 100 * total_corrects / total
        torch.save(model.state_dict(), 'last.pt')   # always keep the weights from the latest epoch

        if accuracy > best_acc:
            best_acc = accuracy
            torch.save(model.state_dict(), 'best.pt')  # keep the weights with the best validation accuracy

        print("epoch:%4d, val loss:%.5f, accuracy:%6.2f%%" % (epoch + 1, loss, accuracy))
        return loss, accuracy
    
 
 
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(traindataset, batch_size=32, shuffle=True)
test_loader = DataLoader(testdataset, batch_size=32, shuffle=True)
for epoch in range(epochs):
    loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)
    loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)
 
classes = testdataset.classes
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

input_tensor = torch.randn(1, 3, 224, 224).to(device)  # dummy input defining the export shape (must be on the same device as the model)

# convert the trained model from .pt to ONNX format
model.eval()
torch.onnx.export(model, input_tensor, "last.onnx")
 
path = './lifecycle/val/yourpath/37.png'  # path to a test image
model.eval()
img = Image.open(path).convert('RGB')  # force 3 channels (PNG files may carry an alpha channel)
img_p = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
    output = model(img_p)
pred = output.argmax(dim=1).item()
plt.imshow(img)
plt.show()
p = 100 * nn.Softmax(dim=1)(output).cpu().numpy()[0]  # class probabilities in percent
print('Predicted class:', classes[pred], ' confidence: %.2f%%' % p[pred])
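If the edge runtime is strict about opset versions or tensor names, torch.onnx.export accepts a few extra arguments. A minimal sketch of a more explicit export call (the opset number, the tensor names "input"/"logits" and the dynamic batch axis are my own illustrative choices, not something this post requires):

# optional: a more explicit export call
model.eval()
dummy = torch.randn(1, 3, 224, 224).to(device)
torch.onnx.export(
    model, dummy, "last.onnx",
    input_names=["input"],        # name given to the graph input
    output_names=["logits"],      # name given to the graph output
    dynamic_axes={"input": {0: "batch"}, "logits": {0: "batch"}},  # allow a variable batch size
    opset_version=11,             # pick whatever opset the target runtime supports
)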

Run the script:

python torch_resnet.py

Output:
[screenshot: per-epoch training/validation printout]
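To sanity-check the exported file before pushing it to the device, a small verification script using the onnx and onnxruntime packages can help (neither package is used by the training script above; this is an optional extra, install them separately if you want the check):

# verify_onnx.py -- sanity-check the exported model
import numpy as np
import onnx
import onnxruntime as ort

onnx_model = onnx.load("last.onnx")
onnx.checker.check_model(onnx_model)   # structural validity of the exported graph

sess = ort.InferenceSession("last.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)  # one random image-sized input
logits = sess.run(None, {input_name: dummy})[0]

print("output shape:", logits.shape)                        # expected: (1, n_classes)
print("predicted class index:", int(logits.argmax(axis=1)[0]))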
