本文来源公众号“Coggle数据科学”,仅用于学术分享,侵权删,干货满满。

原文链接:科大讯飞AI大赛:中医药识别挑战赛 Baseline

  • 赛题名称:中医药识别挑战赛

  • 赛题类型:计算机视觉、图像分类

  • 赛题任务:对中医药图片的自动识别与分类

https://challenge.xfyun.cn/topic/info?type=tcm-recognition&ch=dwsf259

赛题背景

随着中医药产业规模的不断扩大,中药质量参差不齐的问题逐渐凸显。中药品种繁多、来源复杂,且存在大量外观相似的品种,仅凭传统的人工经验识别,不仅效率低下,还容易出现误判,难以满足现代中医药产业快速发展的需求。例如在中药生产企业中,中药材的采购验收环节,若因识别错误混入伪劣品,将直接影响中成药的质量与疗效;在医疗机构,中药饮片的准确识别关乎处方用药的安全与有效性。

在此背景下,人工智能(AI)技术凭借其强大的数据处理和识别能力,为解决中医药识别难题提供了新的途径。利用AI算法对海量的中医药图像、特征数据进行学习和分析,能够实现快速、精准的中医药识别,有效提升中医药质量控制水平,推动中医药产业的规范化、现代化发展。“中医药识别挑战赛”通过AI算法,突破中医药识别的技术瓶颈,为中医药产业的高质量发展注入新动能。

赛题任务

本次比赛旨在通过人工智能技术实现对中医药图片的自动识别与分类。比赛的参赛者需要基于提供的中医药图片数据集,设计和训练一个能够高效、准确地识别和分类不同中医药的人工智能模型。

比赛的评判标准为模型的准确率、召回率、精确率等指标。通过本次比赛,我们期待能够开发出一种高效、准确的算法,以提高中医药识别能力,从而实现有效提升中医药质量控制水平能力。

数据说明

所提供的数据大约包含54个中医药类别和对应的中医药图片,图像分为训练数据库和测试数据库两个子数据库。

其中中医药图片名称对应的label如下表:

category

label

category

label

Honeysuckles

0

Boheye

27

Gardenia

1

Hairyveinagrimony

28

Tianhukui

2

Feipeng

29

Gouweibacao

3

Guizhencao

30

Shuiqincai

4

Eichhorniacrassipes

31

Morningglory

5

Dandelions

32

Bosipopona

6

Zhajiangcao

33

Mantuoluo

7

Wahlenbergia

34

Tongquancao

8

Radixisatidis

35

Perillas

9

Mangnoliaofficinalis

36

Jicai

10

Odoratum

37

Xiaoji

11

Cangerzi

38

Angelica

12

Commelina_communis

39

Heshouwu

13

Chenopodiumalbum

40

Yichuanhong

14

Monochoriavaginalis

41

Malan

15

Ziyunying

42

Rabdosiaserra

16

Pinellia

43

Zeqi

17

Hongliao

44

Bupleurum

18

Moneygrass

45

Plantains

19

Lotusseed

46

Ginsengs

20

Ophiopogon

47

Juaner

21

Qigucao

48

Kucai

22

Huanghuacai

49

Selfheals

23

Wormwood

50

Sedum_sarmentosum

24

Palms

51

Agastacherugosa

25

Denglongcao

52

Xunma

26

Xiaoqieyi

53

评估指标

本模型依据提交的结果文件,采用macro F1-score进行评价。

  • 本赛题均提供下载数据,选手根据要求提交作品。

  • 每支队伍每天最多提交3次。

  • 排行按照得分从高到低排序,将选择队伍的历史最优成绩进行排名。

赛题 Baseline

https://github.com/datawhalechina/competition-baseline/tree/master/competition/%E7%A7%91%E5%A4%A7%E8%AE%AF%E9%A3%9EAI%E5%BC%80%E5%8F%91%E8%80%85%E5%A4%A7%E8%B5%9B2025

  • 加载数据集和模型

class XFDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        try:
            img = Image.open(self.img_path[index]).convert('RGB')
        except:
            img = Image.new('RGB', (224, 224), color='white')

        label = torch.from_numpy(np.array(self.img_label[index]))
        if self.transform is not None:
            img = self.transform(img)
        
        return img, label
    
    def __len__(self):
        return len(self.img_path)
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=len(label_index))
model = model.cuda()
  • 模型训练与验证

train_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[:-500], train_label[:-500], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomVerticalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=10, shuffle=True, num_workers=4, pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[-500:], train_label[-500:], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=10, shuffle=False, num_workers=4, pin_memory=True
)

criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
best_acc = 0.0
for epoch in range(20):
    scheduler.step()
    print('Epoch: ', epoch)

    train(train_loader, model, criterion, optimizer, epoch)
    val_acc = validate(val_loader, model, criterion)
    
    if val_acc.avg.item() > best_acc:
        best_acc = round(val_acc.avg.item(), 2)

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐