1. torch.utils.data.random_split()
PyTorch offers several ways to split a dataset, but this one is the simplest.
Source: https://www.cnblogs.com/marsggbo/p/10496696.html
    import torch

    # full_dataset is assumed to be an existing torch.utils.data.Dataset.
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
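After the split, the training and test sets both have the type:
<class 'torch.utils.data.dataset.Subset'>
The original Dataset has become two Subset objects. Subset is itself a subclass of Dataset, so either kind can be passed to torch.utils.data.DataLoader() to build an iterable DataLoader.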
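The split can be tried end to end without downloading anything. The following is a minimal sketch, not part of the original post: the TensorDataset is a hypothetical stand-in for full_dataset, and the seeded generator argument is an optional addition that makes the random split repeatable.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset, random_split

# Hypothetical stand-in for full_dataset: 100 samples, 3 features, binary labels.
features = torch.arange(300, dtype=torch.float32).reshape(100, 3)
labels = torch.arange(100) % 2
full_dataset = TensorDataset(features, labels)

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# Assumption: a seeded generator is passed so the split is the same on every run.
train_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(0),
)

# Both parts are Subset objects that index into the original dataset.
is_subset = isinstance(train_dataset, Subset) and isinstance(test_dataset, Subset)

# The two index sets do not share any sample.
train_indices = {int(i) for i in train_dataset.indices}
test_indices = {int(i) for i in test_dataset.indices}
no_overlap = train_indices.isdisjoint(test_indices)

# A Subset can be handed to DataLoader just like the original Dataset.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
first_batch_x, first_batch_y = next(iter(train_loader))
```

Because the two Subset objects only store indices into full_dataset, the split itself does not copy any sample data.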
2. torch.utils.data.Subset()
https://*.com/questions/47432168/taking-subsets-of-a-pytorch-dataset
    import torch
    import torchvision

    # ToTensor is used instead of transform=None: with no transform the samples
    # are PIL images, which the default DataLoader collate function cannot batch.
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                            transform=torchvision.transforms.ToTensor())

    evens = list(range(0, len(trainset), 2))  # even indices
    odds = list(range(1, len(trainset), 2))   # odd indices
    trainset_1 = torch.utils.data.Subset(trainset, evens)  # Subset type
    trainset_2 = torch.utils.data.Subset(trainset, odds)   # Subset type

    # Build a DataLoader from each Subset object
    trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4, shuffle=True, num_workers=2)
    trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4, shuffle=True, num_workers=2)
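To see which samples a Subset actually exposes, the same even/odd split can be applied to a tiny synthetic dataset. This is only an illustrative sketch: toy_dataset is a hypothetical stand-in for CIFAR10 in which sample i holds the value i, so the origin of every item is visible.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for CIFAR10: ten samples, sample i holds the value i.
toy_dataset = TensorDataset(torch.arange(10))

evens = list(range(0, len(toy_dataset), 2))  # even indices
odds = list(range(1, len(toy_dataset), 2))   # odd indices
even_subset = Subset(toy_dataset, evens)
odd_subset = Subset(toy_dataset, odds)

# Indexing a Subset at position i returns the parent dataset's item at indices[i].
even_values = [int(even_subset[i][0]) for i in range(len(even_subset))]
odd_values = [int(odd_subset[i][0]) for i in range(len(odd_subset))]

# With shuffle=False the DataLoader walks the subset in index order.
loader = DataLoader(even_subset, batch_size=2, shuffle=False)
batches = [batch[0].tolist() for batch in loader]
```

Here even_values contains only the values stored at even positions of the parent dataset, and the last batch is smaller because five samples do not divide evenly by a batch size of two.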
3.