AI learning 用于学习train,验证val的图片目录结构形式如下:(一般为自己构造的图像数据的目录)

这是一个简单的图像二分类问题,两个类别为正常(normal)或者异常(abnormal)。

数据集分为:train 训练集,val 验证集,test 测试集。

data---

        ---train               

                -----abnormal

                        ----001.jpg

                        ----002.jpg

                        ----....

                -----normal            

                        ----001.jpg

                        ----002.jpg

                        ----....

        ---val            

                -----abnormal

                        ----001.jpg

                        ----002.jpg

                        ----....

                -----normal

                        ----001.jpg

                        ----002.jpg

                        ----....

        ---test

                ----001.jpg

                ----002.jpg

                ----....

使用Dataset 继承,需要重新写自己的dataset函数,包含标签(abnormal 为标签1, normal 为标签0),有标签的情况主要是用于学习和验证使用。

from torch.utils.data import Dataset
from torchvision import transforms

def get_label(root, phase):
    label_list =[]
    img_list1 = []
    img_root = os.path.join(root,phase)
    imgs = os.listdir(img_root)
    for im in imgs:
        image_list = os.listdir(os.path.join(img_root,im))
        for img_path in image_list:
            img_list = os.path.join(os.path.join(img_root, im),img_path)
            label = 1 if img_list.split('\\')[-2] == 'abnormal' else 0
            label_list.append(label)
            img_list1.append(img_list)
    return img_list1, label_list

class MyData(Dataset):
    def __init__(self, root_dir, phase, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.phase = phase
        self.data = self.load_img(self)
    def load_img(self):
        image_list, label_list=get_label(self.root_dir,self.phase)
        data =[]
        for im in range(len(image_list)):
            img = Image.open(image_list[im]).convert('RGB')
            sample =(img,label_list[im])
            data.append(sample)
        return data
    def __len__(self):
        return len(self.data)
    def __getitem__(self,index):
        image_info,img_label = self.data[index]
        if self.transform:
            sample = self.transform(image_info)
        else:
            sample = image_info
        return sample,img_label

无标签的情况,主要是来进行测试用。

def get_images(root):
    img_list1 = []
    img_root = os.path.join(root)
    imgs = os.listdir(img_root)
    for im in imgs:
        image_list = os.path.join(img_root,im)
        img_list1.append(image_list)
    return img_list1

class MyDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = self.load_img()
    def load_img(self):
        image_list =get_images(self.root_dir)
        data =[]
        for im in image_list:
            img = Image.open(im).convert('RGB')
            data.append(img)
        return data
    def __len__(self):
        return len(self.data)
    def __getitem__(self,index):
        image_info = self.data[index]
        if self.transform:
            sample = self.transform(image_info)
            return sample
        else:
            return image_info

调用

from torch.utils.data import DataLoader,Dataset

data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])
train_dataset = MyData('data','train', transform = data_transform)
val_dataset = MyData('data','val', transform = data_transform)
test_dataset = MyDataset('data\\test', transform = data_transform)
test_loader = Dataloader(test_dataset, batch_size = 32)
for step,data in enumerate(test_loader):
    images = data
  # [预测代码]

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐