原作者: Pytorch实现中药材(中草药)分类识别(含训练代码和数据集)
本人在原作者代码基础上进行改进,这里先给出原作者链接,以示尊重原作。

前言:基于原作者代码,经训练了mobilenet_v2、resnet18和resnet34三种模型效果后,本人选择识别准确率最高的resnet34进行识别。然通过网络爬虫爬取各类中草药功能信息。使用QT界面展示每识别一张图片输出对应的功能信息。

效果

在这里插入图片描述

resnet34

  • 直接调用pytorch预训练好的resnet34模型
from torchvision.models.resnet import model_urls
	backbone = models.resnet34(pretrained=False)

tensorboard可视化

  • 终端运行以下一行代码,点开输出的网址。
tensorboard --logdir=路径/resnet34_1.0_CrossEntropyLoss_20230524125658/log

mobilenet_v2:
在这里插入图片描述

resnet18:
在这里插入图片描述
resnet34:
在这里插入图片描述

  • 综上,resnet34的 train 准确率为99.28%, test 准确率为98.42。

网络爬虫

  • 通过搜索每一种中草药,爬取对应的功能信息。请求需间隔一定时间,不然请求会失败,可能出现阿贾克斯请求;同时在爬取时不要登录账号。爬取163类中草药需进行两次,爬到一半左右程序会报错中断,这时需将代码47行的for循环的索引进行对应的调整再继续爬取剩下的。
from bs4 import BeautifulSoup
import urllib.request
import re
import time

class ZYC():
    def __init__(self):
        # 伪装成浏览器访问,适用于拒绝爬虫的网站
        self.headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.50'}

    def gethtml(self, name):
        # 修改链接拼接方式
        full_path = "/zhongyaocai"
        params = {
            'keyword': name,
            'jc_me_name': 1,
            'me_pinyin': 1,
            'me_englishname': 1,
            'me_latinname': 1,
            'use': '',
            'hauxuechenhfen': '',
            'zhuyi': '',
            'yctype': '全部',
        }
        encoded_params = urllib.parse.urlencode(params)

        # 拼接完整的 URL,请参考修改的参数
        url = f'https://db.yaozh.com{full_path}?{encoded_params}'

        # 在创建 Request 对象时指定 headers 参数,而不是使用 encoding 参数
        self.headers['Content-Type'] = 'application/json;charset=utf-8'
        req = urllib.request.Request(url, headers=self.headers)

        req_timeout = 5
        response = urllib.request.urlopen(req, None, req_timeout)
        html = response.read().decode('utf-8')
        return html

    # 获取自己想要的内容
    def getinformation(self):
        # 读取中药名称
        words = []  # 存储txt中药材名称
        with open('class_names.txt', 'r') as class_name:
            for line in class_name:
                word = line.strip()  # strip()方法用于去除每行末尾的字符(默认是空格和换行符)
                words.append(word)  # 将当前行读到的单词添加到列表中
        for name in range(len(words)):
            html = self.gethtml(words[name])  # 传入药材名称
            reg = re.compile(r"</p>") # 利用正则表达式去除
            html = reg.sub('', html)
            reg = re.compile(r"<p>")
            html = reg.sub('', html)
            soup = BeautifulSoup(html, "html.parser") # 转换成 BeautifulSoup 对象
            Trlist = soup.find_all('tr')

            # 获取
            try:
                # # 获取标题
                # for item in Trlist[0]:  # 包含了第一页中所有的列名
                #     if item not in ['\n', '\t', ' ']:  # 检测换行符或空白符
                #         item = item.get_text(strip=True)  # 转换为纯文本格式
                #         with open("Chinese_herbal.txt", "a") as file:
                #             file.write(item + '|')  # 分隔符

                # 获取内容
                file = open("Chinese_herbal.txt", "a")
                file.write('\n')
                for te in Trlist[1]:
                    for item in te:
                        if item not in ['\n', ' ', '\s']:
                            item = item.get_text(strip=True)
                            reg = re.compile(r'\s+')  # 正则表达式匹配任意连续的空白符
                            item = reg.sub('', item)  # 替换为空
                            file.write(item + '|')
                file.close()
                print("--正在采集%s信息--" % words[name])
                time.sleep(5)  # 延迟程序执行的时间;等待5秒,然后再进行下一个网页的爬取。
            except IndexError:
                print(f"No data found for {words[name]}")


if __name__ == '__main__':
    ZYC().getinformation()

demo

代码逻辑:

导入库

import sys
sys.path.append("libs")
import argparse
from basetrainer.utils import setup_config
from pybaseutils import file_utils, image_utils
from classifier import inference
from PyQt5.QtWidgets import QApplication, QMainWindow
from ui import Ui_MainWindow
from PyQt5 import QtGui

再看主函数

if __name__ == "__main__":
    # 路径——>解析器——>预测分类——>实现
    parser = get_parser()
    # cfg配置对象;setup_config解析命令行参数(parser.parse_args():命令行参数解析器, cfg_updata使用默认值)。
    cfg = setup_config.parser_config(parser.parse_args(), cfg_updata=False)
    t = Predictor(cfg)
    # (image_dir图像目录,shuffle=False表示不要打乱目录中文件的顺序(默认为True,即打乱))
    t.image_dir_predict(cfg.image_dir, shuffle=False)

该步骤进行配置文件和训练好模型文件的调用

def get_parser():
    # 配置文件
    # config_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/config.yaml"
    # # 模型文件
    # model_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/model/best_model_116_98.4700.pth"
    # 配置文件
    config_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/config.yaml"
    # 模型文件
    model_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/model/best_model_114_98.5700.pth"

    # 待测试图片目录
    image_dir = "data/test_images"
    # 创建用于解析命令行参数的ArgumentParser对象;用于读取命令行输入,并将输入转换为其他数据类型以供程序使用。
    parser = argparse.ArgumentParser(description="Inference Argument")
    # 使用argparse库来解析命令行参数的函数。
    # ("-c", "--config_file":参数名称,分别代表短选项何长选项;
    # help=:该选项的描述信息;default=:不提供选项则使用默认值;type=:类型)
    parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
    parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
    parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
    parser.add_argument("--image_dir", help="image file or directory", default=image_dir, type=str)
    return parser

再而进入识别的类中

class Predictor(inference.Inference):
    def __init__(self, cfg):
        super(Predictor, self).__init__(cfg)

    def predict(self, image):
        """预测类别"""
        # pred_index类别标签,pred_score置信度得分
        pred_index, pred_score = self.inference(image)
        # 将模型输出的类别pred_index使用self.label2class_name函数进行转换,返回对应的类别名称
        pred_index = self.label2class_name(pred_index)
        return pred_index, pred_score

然后进行图片的识别,以得到图片识别的类别和准确率

    def image_dir_predict(self, image_dir, vis=True, use_rgb=True, shuffle=False):
        """
        :param image_dir: list,*.txt ,图像路径或目录
        :param vis: 是否可视化
        :param use_rgb: 是否转换为RGB格式
        :param shuffle:  是否打乱顺序
        """
        pred_index_base = [] # 存储名称,穿进嵌套的类中
        formatted_string_base = [] # 存储准确率,穿进嵌套的类中
        image_list = file_utils.get_files_lists(image_dir, shuffle=shuffle) # 读取image_dir
        cnt = len(image_list) # 统计长度
        for path in image_list:
            image = image_utils.read_image_ch(path, use_rgb=use_rgb)  # 读取图像文件,对该文件进行解码并将其转换为numpy数组。
            pred_index, pred_score = self.predict(image)  # 预测类别
            value = float(pred_score[0])  # 准确率位数限制
            formatted_string = "%.5f" % value
            pred_index_base.append(pred_index[0])
            formatted_string_base.append(formatted_string)

这里嵌套QT界面的类,如果两个类合并则会出现冲突,所以在上一步将所有识别的图片名称、准确率存储进列表,然后在QT类中使用计数器原理进行调用,当按钮触发后才进入下一张图片的展示。

		# 嵌套QT界面
        if vis:  # 图像显示
            class MainWindow(QMainWindow, Ui_MainWindow):
                def __init__(self):
                    super().__init__()
                    # 设置UI
                    self.setupUi(self)
                    self.tmp = 0 # 计数器
                    self.pushButton.clicked.connect(self.change_image) # 按钮

                def change_image(self):
                    if self.tmp == cnt: # 达到最后一张照片则
                        sys.exit() # 关闭界面
                    else:
                        self.textBrowser.clear() # 清空上一次的信息
                        self.textBrowser_2.clear()
                        self.textBrowser_3.clear()
                        self.textBrowser_4.clear()
                        self.textBrowser_5.clear()
                        self.textBrowser_6.clear()
                        path = image_list[self.tmp]
                        # 将图片设置为QLabel的内容
                        self.label_2.setPixmap(QtGui.QPixmap(path)) # 图片
                        self.textBrowser.append(pred_index_base[self.tmp]) # 名称
                        self.textBrowser_2.append(str(formatted_string_base[self.tmp])) # 准确率
                        self.BinarySearch(pred_index_base[self.tmp]) # 详情
                        self.tmps() # 计数器

                def tmps(self):
                    self.tmp += 1
                # 信息输出
                def BinarySearch(self, path):
                    # 读取中药名称
                    words = []  # 存储txt中药材名称
                    with open('class_names.txt', 'r') as f:  # 读取中药名称
                        for line in f:
                            word = line.strip()  # strip()方法用于去除每行末尾的字符(默认是空格和换行符)
                            words.append(word)  # 将当前行读到的单词添加到列表中
                    # 查找
                    with open('Chinese_herbal.txt', 'r') as t:  # 读取中药信息
                        temp = t.read()
                        line_list = temp.splitlines()
                        if len(path) <= 2:  # 二分查找思想
                            for i in range(91):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])
                        else:
                            for i in range(91, 163):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])

            app = QApplication(sys.argv)
            main_win = MainWindow()
            main_win.show()
            sys.exit(app.exec_())

完整代码

import sys
sys.path.append("libs")
import argparse
from basetrainer.utils import setup_config
from pybaseutils import file_utils, image_utils
from classifier import inference
from PyQt5.QtWidgets import QApplication, QMainWindow
from ui import Ui_MainWindow
from PyQt5 import QtGui

class Predictor(inference.Inference):
    def __init__(self, cfg):
        super(Predictor, self).__init__(cfg)

    def predict(self, image):
        """预测类别"""
        # pred_index类别标签,pred_score置信度得分
        pred_index, pred_score = self.inference(image)
        # 将模型输出的类别pred_index使用self.label2class_name函数进行转换,返回对应的类别名称
        pred_index = self.label2class_name(pred_index)
        return pred_index, pred_score

    def image_dir_predict(self, image_dir, vis=True, use_rgb=True, shuffle=False):
        """
        :param image_dir: list,*.txt ,图像路径或目录
        :param vis: 是否可视化
        :param use_rgb: 是否转换为RGB格式
        :param shuffle:  是否打乱顺序
        """
        pred_index_base = [] # 存储名称,穿进嵌套的类中
        formatted_string_base = [] # 存储准确率,穿进嵌套的类中
        image_list = file_utils.get_files_lists(image_dir, shuffle=shuffle) # 读取image_dir
        cnt = len(image_list) # 统计长度
        for path in image_list:
            image = image_utils.read_image_ch(path, use_rgb=use_rgb)  # 读取图像文件,对该文件进行解码并将其转换为numpy数组。
            pred_index, pred_score = self.predict(image)  # 预测类别
            value = float(pred_score[0])  # 准确率位数限制
            formatted_string = "%.5f" % value
            pred_index_base.append(pred_index[0])
            formatted_string_base.append(formatted_string)

        # 嵌套QT界面
        if vis:  # 图像显示
            class MainWindow(QMainWindow, Ui_MainWindow):
                def __init__(self):
                    super().__init__()
                    # 设置UI
                    self.setupUi(self)
                    self.tmp = 0 # 计数器
                    self.pushButton.clicked.connect(self.change_image) # 按钮

                def change_image(self):
                    if self.tmp == cnt: # 达到最后一张照片则
                        sys.exit() # 关闭界面
                    else:
                        self.textBrowser.clear() # 清空上一次的信息
                        self.textBrowser_2.clear()
                        self.textBrowser_3.clear()
                        self.textBrowser_4.clear()
                        self.textBrowser_5.clear()
                        self.textBrowser_6.clear()
                        path = image_list[self.tmp]
                        # 将图片设置为QLabel的内容
                        self.label_2.setPixmap(QtGui.QPixmap(path)) # 图片
                        self.textBrowser.append(pred_index_base[self.tmp]) # 名称
                        self.textBrowser_2.append(str(formatted_string_base[self.tmp])) # 准确率
                        self.BinarySearch(pred_index_base[self.tmp]) # 详情
                        self.tmps() # 计数器

                def tmps(self):
                    self.tmp += 1
                # 信息输出
                def BinarySearch(self, path):
                    # 读取中药名称
                    words = []  # 存储txt中药材名称
                    with open('class_names.txt', 'r') as f:  # 读取中药名称
                        for line in f:
                            word = line.strip()  # strip()方法用于去除每行末尾的字符(默认是空格和换行符)
                            words.append(word)  # 将当前行读到的单词添加到列表中
                    # 查找
                    with open('Chinese_herbal.txt', 'r') as t:  # 读取中药信息
                        temp = t.read()
                        line_list = temp.splitlines()
                        if len(path) <= 2:  # 二分查找思想
                            for i in range(91):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])
                        else:
                            for i in range(91, 163):
                                if words[i] == path:
                                    # 输出
                                    lines = line_list[i + 1].split('|')
                                    self.textBrowser_3.append(lines[2])
                                    self.textBrowser_4.append(lines[3])
                                    self.textBrowser_5.append(lines[4])
                                    self.textBrowser_6.append(lines[5])

            app = QApplication(sys.argv)
            main_win = MainWindow()
            main_win.show()
            sys.exit(app.exec_())


def get_parser():
    # 配置文件
    # config_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/config.yaml"
    # # 模型文件
    # model_file = "data/pretrained/resnet18_1.0_CrossEntropyLoss_20230404151914/model/best_model_116_98.4700.pth"
    # 配置文件
    config_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/config.yaml"
    # 模型文件
    model_file = "work_space/resnet34_1.0_CrossEntropyLoss_20230524125658/model/best_model_114_98.5700.pth"

    # 待测试图片目录
    image_dir = "data/test_images"
    # 创建用于解析命令行参数的ArgumentParser对象;用于读取命令行输入,并将输入转换为其他数据类型以供程序使用。
    parser = argparse.ArgumentParser(description="Inference Argument")
    # 使用argparse库来解析命令行参数的函数。
    # ("-c", "--config_file":参数名称,分别代表短选项何长选项;
    # help=:该选项的描述信息;default=:不提供选项则使用默认值;type=:类型)
    parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
    parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
    parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
    parser.add_argument("--image_dir", help="image file or directory", default=image_dir, type=str)
    return parser


if __name__ == "__main__":
    # 路径——>解析器——>预测分类——>实现
    parser = get_parser()
    # cfg配置对象;setup_config解析命令行参数(parser.parse_args():命令行参数解析器, cfg_updata使用默认值)。
    cfg = setup_config.parser_config(parser.parse_args(), cfg_updata=False)
    t = Predictor(cfg)
    # (image_dir图像目录,shuffle=False表示不要打乱目录中文件的顺序(默认为True,即打乱))
    t.image_dir_predict(cfg.image_dir, shuffle=False)

项目不足

  • 内存空间大
    当识别的图片过多时,存储名称、准确率的列表需求空间较大。
  • 测试对象静态
    后续可将图片换为视频或者实时捕获。
Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐