创建训练和测试的数据集,并创建加载器。lib.load_imags.py:
import glob
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from script.setting import *
# 图片处理函数
def img_loader(path):
try:
img = Image.open(path)
img = img.convert('RGB') # 转换成RGB模式
if img.size != normalize_size: # 标准化图片尺寸
img = transforms.Resize(normalize_size)(img) # 提供新尺寸
return img
except Exception as e:
print(f"Error loading image {path}: {e}")
return error_default_img # 返回一个默认的错误图像
# 训练集数据预处理方法
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomVerticalFlip(), # 随机垂直翻转
transforms.RandomRotation(90), # 随机旋转90度
# transforms.RandomGrayscale(p=0.1), # 随机将图片转换为灰度图,p=0.1表示有10%的概率执行该操作
# transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), # 调整图片的亮度、对比度、饱和度和色调
transforms.ToTensor(), # 将图片转换为Tensor
transforms.Normalize(normalize_mean, # 标准化
normalize_std)
])
# 测试集数据预处理方法
test_transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为Tensor
transforms.Normalize(normalize_mean, # 标准化
normalize_std)
])
# 自定义数据集
class MyDataset(Dataset):
def __init__(self, img_list, # 图片的地址列表
transform=None):
super(MyDataset, self).__init__()
imgs = [] # 图片的路径和类别标签列表(转换成了0-9的数字的)存储格式为:
# [[图片路径, 0-9的图片类别], ...]
for img_path in img_list:
im_label_name = img_path.split('\\')[-2] # 图片所属类别的名称,这里使用的是绝对路径,
# 文件目录分隔符为反斜杠,使用相对路径则为正斜杠
imgs.append([img_path, label_dict[im_label_name]]) # 将图片路径和对应的类别标签添加到列表中
self.imgs = imgs
# print(imgs[0])
self.transform = transform
self.loader = img_loader # 图片加载函数
def __getitem__(self, index): # 获取数据集的图片和标签
img_path, label = self.imgs[index]
img_data = self.loader(img_path) # 图片的数据
if img_data is None:
return None, label # 处理加载错误的情况
if self.transform is not None:
img_data = self.transform(img_data) # 图片数据转换
return img_data, label # 返回图片数据和标签
def __len__(self): # 获取数据集的图片数量 train_path
return len(self.imgs)
# 获取训练集的文件名
train_list = glob.glob(train_path + '\\*\\*')
# 获取测试集的文件名
test_list = glob.glob(test_path + '\\*\\*')
# 定义训练数据集
trans_dataSet = MyDataset(train_list, transform=train_transform)
# 训练集图片数量
train_num = len(trans_dataSet)
# 定义测试数据集
test_dataSet = MyDataset(test_list, transform=test_transform)
# 测试集图片数量
test_num = len(test_dataSet)
# 定义训练集的加载器
train_loader = DataLoader(trans_dataSet, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# 定义测试集的加载器
test_loader = DataLoader(test_dataSet, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# print("num_of_train", len(train_loader)) # 391(50000/128),相当于有391个batch,每个batch有128个样本
# print("num_of_test", len(test_loader)) # 79(10000/128),相当于有79个batch,每个batch有128个样本