pytorch 中 DataLoader 和 Dataset 的使用

加载顺序

pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环调用dataloader对象,获取data,label数据拿到模型中去训练

Dataset

你需要自己定义一个class继承父类Dataset,其中至少需要重写以下3个函数:
①__init__:传入数据,或者加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__: 返回一条训练数据,并将其转换成tensor

示例代码:

class MyData(Dataset):

  def __init__(self, x_patches, y_patches, transform = None):
    self.y_patches = clean_patches
    self.x_patches = blur_patches
    self.transform = transform

  def __len__(self):
    return len(self.y_patches)

  def __getitem__(self, idx):
    y_image = self.y_patches[idx]
    x_image = self.x_patches[idx]

    y_image = np.asarray(y_image)
    x_image = np.asarray(x_image)

    y_image = Image.fromarray(y_image.astype(np.uint8))
    x_image = Image.fromarray(x_image.astype(np.uint8))

    if self.transform:
       y_image = self.transform(y_image)
       x_image = self.transform(x_image)

    return x_image, y_image

DataLoader

参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:这个参数可以自己操作每个batch的数据 参考:https://blog.csdn.net/kahuifu/article/details/108654421

示例代码:

dataset = MyData(x_patches, y_patches, transform=transforms.Compose(
            [transforms.ToTensor(), 
             transforms.Normalize([0.5], [0.5])]))

bs = 16
data_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
num_batches = len(data_loader)

调用DateLoader

最后循环调用dataloader ,拿到数据放入模型进行训练

for n_batch, (x_batch, y_batch) in enumerate(data_loader):

    x_data = x_batch.float().cuda()
    y_data = y_batch.float().cuda()

上一篇:编译问题处理方法


下一篇:Android Studio 更新失败的解决办法