LibTorch之图像分类
LibTorch之图像分类。
·
LibTorch之图像分类
数据集地址:https://download.pytorch.org/tutorial/hymenoptera_data.zip
LibTorch之全连接层(torch::nn::Linear)使用
卷积层
LibTorch实现MLP(多层感知机)
LibTorch实现LeNet
训练
#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h>
using namespace std;
namespace fs = std::filesystem;
/// <summary>
/// Collects (image path, label) pairs from a dataset laid out as one
/// sub-directory per class, e.g. data_dir/ants/*.jpg and data_dir/bees/*.jpg.
/// </summary>
/// <param name="data_dir">Root directory containing one sub-directory per class.</param>
/// <param name="dict_label">Maps a class sub-directory name to its integer label.
/// Names that have no matching sub-directory are skipped.</param>
/// <returns>
/// One pair per regular file found in a class directory. Classes appear in the
/// map's key order; files inside a class are sorted by path so the result is
/// reproducible across runs and platforms.
/// </returns>
/// <exception cref="std::filesystem::filesystem_error">
/// Thrown if data_dir is not a directory, or if a class directory cannot be read.
/// </exception>
std::vector<std::pair<std::string, int>> get_imgs_labels(const std::string& data_dir,
                                                         const std::map<std::string, int>& dict_label)
{
    const std::filesystem::path root{ data_dir };
    if (!std::filesystem::is_directory(root)) {
        // Fail loudly instead of silently producing an empty dataset.
        throw std::filesystem::filesystem_error(
            "get_imgs_labels: data directory not found", root,
            std::make_error_code(std::errc::no_such_file_or_directory));
    }

    std::vector<std::pair<std::string, int>> data_info;
    for (const auto& [class_name, label] : dict_label) {
        // Join with operator/ so the separator is correct on every platform
        // (a hard-coded "\\" only works on Windows).
        const std::filesystem::path class_dir = root / class_name;
        if (!std::filesystem::is_directory(class_dir)) {
            continue;  // class listed in dict_label but absent on disk
        }

        std::vector<std::string> files;
        for (const auto& entry : std::filesystem::directory_iterator(class_dir)) {
            if (entry.is_regular_file()) {  // ignore nested folders, sockets, etc.
                files.push_back(entry.path().string());
            }
        }
        // directory_iterator order is unspecified; sort for a deterministic dataset.
        std::sort(files.begin(), files.end());

        for (auto& file : files) {
            data_info.emplace_back(std::move(file), label);
        }
    }
    return data_info;
}
/// <summary>
/// Image-classification dataset: each sample is an image file paired with the
/// integer label of the class folder it was found in.
/// </summary>
class MyDataset : public torch::data::Dataset<MyDataset> {
private:
    // (image path, label) pairs produced by get_imgs_labels() in the constructor.
    std::vector<std::pair<std::string, int>> data_info;

public:
    /// <summary>Indexes every image under data_dir whose class folder appears in dict_label.</summary>
    MyDataset(const std::string& data_dir, std::map<std::string, int> dict_label);

    /// <summary>Loads and preprocesses sample <paramref name="index"/> into an Example{data, target}.</summary>
    torch::data::Example<> get(size_t index) override;

    /// <summary>Number of (image, label) pairs in the dataset.</summary>
    torch::optional<size_t> size() const override {
        return data_info.size();
    }
};
/// <summary>
/// Builds the dataset by pairing every image under data_dir with the label of
/// the class folder it lives in (delegates to get_imgs_labels).
/// </summary>
/// <param name="data_dir">Root folder holding one sub-folder per class.</param>
/// <param name="dict_label">Class-folder name to integer label.</param>
MyDataset::MyDataset(const std::string& data_dir, std::map<std::string, int> dict_label)
    : data_info(get_imgs_labels(data_dir, dict_label))
{
}
/// <summary>
/// Loads sample <paramref name="index"/> and converts it into a training example.
/// The image is read with cv::imread (default flag, 3 channels), resized to
/// 224x224, converted from HWC uint8 to CHW float32 and scaled to [0, 1].
/// The label becomes a 0-dim int64 tensor, the dtype nll_loss requires for targets.
/// </summary>
/// <param name="index">Position in data_info; must be less than size().</param>
/// <returns>Example{image tensor of shape [3, 224, 224], label tensor}.</returns>
/// <exception cref="std::out_of_range">index is out of bounds.</exception>
/// <exception cref="std::runtime_error">the image file cannot be read/decoded.</exception>
torch::data::Example<> MyDataset::get(size_t index)
{
    const auto& [img_path, label] = data_info.at(index);

    cv::Mat image = cv::imread(img_path);
    if (image.empty()) {
        // from_blob on an empty Mat would wrap a null buffer.
        throw std::runtime_error("MyDataset::get: failed to read image: " + img_path);
    }

    cv::resize(image, image, cv::Size(224, 224));

    // from_blob does not own image.data; .to(kFloat32) materialises a copy, so the
    // returned tensor stays valid after `image` is destroyed. 255 is the uint8 max.
    torch::Tensor input_tensor =
        torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte)
            .permute({ 2, 0, 1 })
            .to(torch::kFloat32)
            .div(255.0);

    // torch::tensor(int) would be int32; nll_loss rejects non-int64 targets.
    torch::Tensor label_tensor = torch::tensor(static_cast<int64_t>(label), torch::kLong);
    return { input_tensor, label_tensor };
}
/// <summary>
/// LeNet实现类
/// </summary>
class LeNet :public torch::nn::Module {
public:
// 构造器
LeNet(int num_classes, int num_linear);
// 前向传播
torch::Tensor forward(torch::Tensor x);
private:
// 具体实现放到构造器实现中
torch::nn::Conv2d conv1{ nullptr };
torch::nn::Conv2d conv2{ nullptr };
torch::nn::Linear fc1{ nullptr };
torch::nn::Linear fc2{ nullptr };
torch::nn::Linear fc3{ nullptr };
};
/// <summary>
/// Creates the five layers and registers them with the base Module so that
/// parameters() and torch::save() include them.
/// </summary>
/// <param name="num_classes">Width of the final layer.</param>
/// <param name="num_linear">Flattened feature count entering fc1. For 224x224
/// input it is 16 * 53 * 53 = 44944 (224 -conv5-> 220 -pool-> 110 -conv5-> 106 -pool-> 53).</param>
LeNet::LeNet(int num_classes, int num_linear)
{
    // Every convolution uses a 5x5 kernel.
    const auto make_conv = [](int in_channels, int out_channels) {
        return torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 5));
    };
    const auto make_linear = [](int in_features, int out_features) {
        return torch::nn::Linear(torch::nn::LinearOptions(in_features, out_features));
    };

    // Feature extractor: 3 -> 6 -> 16 channels.
    conv1 = register_module("conv1", make_conv(3, 6));
    conv2 = register_module("conv2", make_conv(6, 16));

    // Classifier head: num_linear -> 128 -> 32 -> num_classes.
    fc1 = register_module("fc1", make_linear(num_linear, 128));
    fc2 = register_module("fc2", make_linear(128, 32));
    fc3 = register_module("fc3", make_linear(32, num_classes));
}
/// <summary>
/// Forward pass: (conv -> ReLU -> 2x2 max-pool) twice, flatten per sample,
/// then fc1 -> ReLU -> fc2 -> ReLU -> fc3.
/// </summary>
/// <param name="x">Image batch, assumed NCHW float with 3 channels; the spatial
/// size must match the num_linear given to the constructor.</param>
/// <returns>Raw class logits of shape [N, num_classes] (no softmax applied).</returns>
torch::Tensor LeNet::forward(torch::Tensor x)
{
    auto out = torch::max_pool2d(torch::relu(conv1(x)), 2);
    out = torch::max_pool2d(torch::relu(conv2(out)), 2);
    // Flatten each sample but keep the batch dimension, so batch sizes > 1 work.
    out = out.flatten(/*start_dim=*/1);
    out = torch::relu(fc1(out));
    out = torch::relu(fc2(out));
    return fc3(out);
}
int main()
{
try
{
map<string, int> dict_label;
dict_label.insert(pair<string, int>("ants", 0));
dict_label.insert(pair<string, int>("bees", 1));
// 设置dataset
auto dataset_train = MyDataset("D:\\dataset\\hymenoptera_data\\train", dict_label).map(torch::data::transforms::Stack<>());
// batchszie
int batchSize = 1;
// 设置dataloader
auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset_train), batchSize);
// 打印
//for (auto& batch : * dataLoader) {
// auto data = batch.data;
// auto target = batch.target;
// std::cout << data.sizes() << std::endl;
// //std::cout << data.max() << std::endl;
// //std::cout << data << std::endl;
// std::cout << target << std::endl;
// int ssss;
// cin >> ssss;
//}
//auto net = LeNet(5, 44944);
std::shared_ptr<LeNet> net = std::make_shared<LeNet>(2, 44944);
// 优化器
torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
for (size_t epoch = 1; epoch <= 10; ++epoch) {
size_t batch_index = 0;
// 遍历数据集
for (auto& batch : *dataLoader) {
// 梯度清零.
optimizer.zero_grad();
// 前向传播
torch::Tensor prediction = net->forward(batch.data);
cout << "prediction:" << prediction << endl;
cout << "target:" << batch.target << endl;
// 计算损失
torch::Tensor loss = torch::nll_loss(prediction, batch.target);
cout <<"loss:" << loss << endl;
// 反向传播
loss.backward();
// 更新梯度
optimizer.step();
// 间隔 x batch 进行loss打印和模型保存
if (++batch_index % 20 == 0) {
std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
<< " | Loss: " << loss << std::endl;
// 保存模型
torch::save(net, "net.pt");
cout << net->parameters() << endl;
}
}
}
}
catch (const std::exception& e)
{
// step5:打印报错
cout << e.what() << endl;
}
return 0;
}
更多推荐
所有评论(0)