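Loading order
In PyTorch, data is loaded in the following order:
① Create a Dataset object
② Create a DataLoader object that wraps it
③ Loop over the DataLoader to get (data, label) batches and feed them to the model for training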
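The three steps can be sketched end to end as below. PyTorch's built-in TensorDataset stands in for a custom Dataset. The random features and labels are placeholder data and do not come from the original example.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Step 1: create a Dataset (a ready-made TensorDataset over placeholder data)
features = torch.randn(10, 3)            # 10 samples, 3 features each
labels = torch.randint(0, 2, (10,))      # 10 binary labels
dataset = TensorDataset(features, labels)

# Step 2: wrap it in a DataLoader that yields batches of 4 samples
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Step 3: loop over the loader; each iteration yields one (data, label) batch
batch_sizes = []
for data, label in loader:
    batch_sizes.append(data.shape[0])
    # a real training step would feed `data` and `label` to the model here

print(batch_sizes)  # [4, 4, 2]
```

Shuffling changes which samples land in each batch. It does not change the batch sizes, so the last batch holds the 2 leftover samples.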
Dataset
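Define your own class that inherits from Dataset. It typically overrides the following three methods:
① __init__: receives the data or loads it
② __len__: returns the total number of items in the dataset
③ __getitem__: returns one training sample, usually converted to a tensor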
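The three methods can be illustrated with a small toy class before the real example. SquaresDataset is hypothetical: item i is the pair (i, i squared) as float tensors. Indexing with ds[i] calls __getitem__, and len(ds) calls __len__.

```python
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is (i, i**2) as float tensors."""

    def __init__(self, n):
        # store (or load) the raw data once
        self.values = list(range(n))

    def __len__(self):
        # number of items; the DataLoader's sampler uses this to generate indices
        return len(self.values)

    def __getitem__(self, idx):
        # return one sample, converted to tensors
        x = self.values[idx]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(x ** 2, dtype=torch.float32)

ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (tensor(3.), tensor(9.))
```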
Example code:
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
class MyData(Dataset):
    def __init__(self, x_patches, y_patches, transform=None):
        # x_patches: input (blurred) patches; y_patches: target (clean) patches
        self.x_patches = x_patches
        self.y_patches = y_patches
        self.transform = transform
    def __len__(self):
        # one item per (input, target) patch pair
        return len(self.y_patches)
    def __getitem__(self, idx):
        y_image = self.y_patches[idx]
        x_image = self.x_patches[idx]
        # convert to uint8 PIL images so torchvision transforms such as ToTensor can be applied
        y_image = Image.fromarray(np.asarray(y_image).astype(np.uint8))
        x_image = Image.fromarray(np.asarray(x_image).astype(np.uint8))
        if self.transform is not None:
            y_image = self.transform(y_image)
            x_image = self.transform(x_image)
        return x_image, y_image
DataLoader
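Parameters:
dataset: the Dataset to load samples from
batch_size: number of samples per batch (default 1)
shuffle=True: whether to reshuffle the data at every epoch
collate_fn: lets you customize how the samples of each batch are merged into one batch. Reference: https://blog.csdn.net/kahuifu/article/details/108654421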
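One common use of collate_fn is batching samples of different lengths. The default collate function cannot stack these into a single tensor. The sketch below assumes hypothetical 1-D integer sequences. pad_collate is an illustrative helper, not a PyTorch function. It pads each batch to its longest sequence with torch.nn.utils.rnn.pad_sequence and also returns the true lengths.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical samples of different lengths (a plain list works as a map-style dataset)
samples = [torch.tensor([1, 2, 3]), torch.tensor([4]), torch.tensor([5, 6])]

def pad_collate(batch):
    """Hypothetical collate_fn: pad sequences to the longest in the batch, keep true lengths."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths

loader = DataLoader(samples, batch_size=3, shuffle=False, collate_fn=pad_collate)
padded, lengths = next(iter(loader))
print(padded.tolist())   # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
print(lengths.tolist())  # [3, 1, 2]
```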
Example code:
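from torch.utils.data import DataLoader
from torchvision import transforms
# ToTensor scales pixels to [0, 1]; Normalize([0.5], [0.5]) then maps them to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = MyData(x_patches, y_patches, transform=transform)
bs = 16
data_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
num_batches = len(data_loader)  # number of batches per epoch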
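The value of len(data_loader) is the number of batches per epoch. With the default drop_last=False it equals ceil(N / batch_size), so the last batch may be smaller. With drop_last=True it equals floor(N / batch_size). The sketch checks this on a hypothetical 100-sample dataset with bs = 16.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 100 samples, with the batch size used above
dataset = TensorDataset(torch.zeros(100, 1))
bs = 16

loader = DataLoader(dataset, batch_size=bs, shuffle=True)
loader_drop = DataLoader(dataset, batch_size=bs, shuffle=True, drop_last=True)

# drop_last=False keeps the final, smaller batch: ceil(100 / 16) = 7
print(len(loader), math.ceil(len(dataset) / bs))  # 7 7
# drop_last=True discards the incomplete batch: floor(100 / 16) = 6
print(len(loader_drop), len(dataset) // bs)       # 6 6
# the kept last batch holds 100 - 6 * 16 = 4 samples
last_size = [batch[0].shape[0] for batch in loader][-1]
print(last_size)                                  # 4
```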
Iterating over the DataLoader
Finally, loop over the DataLoader to get each batch and feed it to the model for training:
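for n_batch, (x_batch, y_batch) in enumerate(data_loader):
    # .cuda() assumes a CUDA-capable GPU is available
    x_data = x_batch.float().cuda()
    y_data = y_batch.float().cuda()
    # forward pass, loss computation and backpropagation go here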
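A complete training iteration could look roughly like the sketch below. It makes several assumptions. The paired random tensors stand in for x_patches/y_patches. The one-layer Conv2d model, SGD optimizer and MSE loss are hypothetical choices, not the author's. Device selection falls back to the CPU, so it also runs on machines without a GPU.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical paired data standing in for (x_patches, y_patches): 32 samples of 1x8x8
x = torch.rand(32, 1, 8, 8)
y = torch.rand(32, 1, 8, 8)
data_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

# Use the GPU if present, otherwise the CPU, instead of calling .cuda() unconditionally
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Conv2d(1, 1, kernel_size=3, padding=1).to(device)  # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for n_batch, (x_batch, y_batch) in enumerate(data_loader):
    x_data = x_batch.float().to(device)
    y_data = y_batch.float().to(device)

    optimizer.zero_grad()                    # clear gradients from the previous step
    loss = loss_fn(model(x_data), y_data)    # forward pass and loss
    loss.backward()                          # backpropagate
    optimizer.step()                         # update the weights
    print(n_batch, loss.item())
```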