Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)
1.DataLoader
torch.utils.data.DataLoader()
:构建可迭代的数据装载器, 训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。
Dataloader()参数:
- dataset: Dataset类,决定数据从哪读取(数据路径)以及如何读取(做哪些预处理)
- batchsize: 批大小
- num_works: 是否采用多进程读取机制
- shuffle: 每一个epoch是否乱序
- drop_last: 当样本数不能被batchsize整除时,是否舍弃最后一批数据。
2. Dataset
torch.utils.data.Dataset()
:Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()
这个类方法。
__getitem__
方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 看上面的函数,参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。
2.1 ImageFolder
torchvision已经预先实现了常用的Dataset, 其他预先实现的有: torchvision.datasets.CIFAR10
, 可以读取CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集。
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
参数:
- root: 图片路径
- transform: 对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
- target_transform:对label的转换
- loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
示例:
文件夹格式:
train_path = r‘datasets/myDataSet/train‘
预处理格式:
train_transform = transforms.Compose([
transforms.Resize((40,40)),
transforms.RandomCrop(40,padding=4),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],
[0.229,0.224,0.225],)
])
dataset:
trainset = ImageFolder(train_path,transform = train_transform)
# print(trainset[30]) # 元组类型,第30号图片的(像素信息,label)
Data.DataLoader:
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=False)
for i,(img, target) in enumerate(train_loader):
print(i)
print(img.shape) # (batchsize, channel, H, W)
print(target.shape) # (batch)
print(target) # 一个batch图片对应的label
2.2
class myDataset(Data.Dataset):
def __init__(self, path, transform):
self.path = path
self.transform = transform
self.data_info = self.get_img_info(path)
self.label = []
for i in range(len(self.data_info)):
self.label.append(list(self.data_info[i])[1])
def __getitem__(self, idx):
path_img = self.data_info[idx][0]
label = self.label[idx]
img = Image.open(path_img).convert(‘RGB‘) # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label, idx
def __len__(self):
return len(self.data_info)
@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith(‘.jpg‘), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = int(sub_dir)
data_info.append((path_img, int(label)))
return data_info
trainset = myDataset(train_path, train_transform)
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=True)
for i,(img, target, index) in enumerate(train_loader):
print(i)
print(img.shape) # (batchsize, channel, H, W)
print(target.shape) # (batch)
print(target) # 一个batch的图片对应的label
print(index) # 一个batch的图片在数据集中对应的index
s