一个简单的图像分类项目(四)编写脚本:图像加载器

9b53821b1a744cd3b6fd4a65b2a1e4fd.png

创建训练和测试的数据集,并创建加载器。lib.load_imags.py:

import glob
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from script.setting import *


# 图片处理函数
def img_loader(path):
    try:
        img = Image.open(path)
        img = img.convert('RGB')  # 转换成RGB模式
        if img.size != normalize_size:  # 标准化图片尺寸
            img = transforms.Resize(normalize_size)(img)  # 提供新尺寸
        return img
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return error_default_img  # 返回一个默认的错误图像


# 训练集数据预处理方法
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomVerticalFlip(),  # 随机垂直翻转
    transforms.RandomRotation(90),  # 随机旋转90度
    # transforms.RandomGrayscale(p=0.1),  # 随机将图片转换为灰度图,p=0.1表示有10%的概率执行该操作
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),  # 调整图片的亮度、对比度、饱和度和色调
    transforms.ToTensor(),  # 将图片转换为Tensor
    transforms.Normalize(normalize_mean,  # 标准化
                         normalize_std)
])

# 测试集数据预处理方法
test_transform = transforms.Compose([
    transforms.ToTensor(),  # 将图片转换为Tensor
    transforms.Normalize(normalize_mean,  # 标准化
                         normalize_std)
])


# 自定义数据集
class MyDataset(Dataset):
    def __init__(self, img_list,  # 图片的地址列表
                 transform=None):
        super(MyDataset, self).__init__()

        imgs = []  # 图片的路径和类别标签列表(转换成了0-9的数字的)存储格式为:
        # [[图片路径, 0-9的图片类别], ...]

        for img_path in img_list:
            im_label_name = img_path.split('\\')[-2]  # 图片所属类别的名称,这里使用的是绝对路径,
            # 文件目录分隔符为反斜杠,使用相对路径则为正斜杠
            imgs.append([img_path, label_dict[im_label_name]])  # 将图片路径和对应的类别标签添加到列表中

        self.imgs = imgs
        # print(imgs[0])
        self.transform = transform
        self.loader = img_loader  # 图片加载函数

    def __getitem__(self, index):  # 获取数据集的图片和标签
        img_path, label = self.imgs[index]
        img_data = self.loader(img_path)  # 图片的数据

        if img_data is None:
            return None, label  # 处理加载错误的情况

        if self.transform is not None:
            img_data = self.transform(img_data)  # 图片数据转换

        return img_data, label  # 返回图片数据和标签

    def __len__(self):  # 获取数据集的图片数量  train_path
        return len(self.imgs)


# 获取训练集的文件名
train_list = glob.glob(train_path + '\\*\\*')

# 获取测试集的文件名
test_list = glob.glob(test_path + '\\*\\*')

# 定义训练数据集
trans_dataSet = MyDataset(train_list, transform=train_transform)
# 训练集图片数量
train_num = len(trans_dataSet)

# 定义测试数据集
test_dataSet = MyDataset(test_list, transform=test_transform)
# 测试集图片数量
test_num = len(test_dataSet)

# 定义训练集的加载器
train_loader = DataLoader(trans_dataSet, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# 定义测试集的加载器
test_loader = DataLoader(test_dataSet, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# print("num_of_train", len(train_loader))  # 391(50000/128),相当于有391个batch,每个batch有128个样本
# print("num_of_test", len(test_loader))  # 79(10000/128),相当于有79个batch,每个batch有128个样本

上一篇:java_方法递归调用


下一篇:Pytest用例执行顺序和跳过执行详解