使用pytorch导入自建数据集
以mini_imagenet为例
其实是关键需要数据集的结构为
data train 类别1 image1 image2 …… 类别2 image1 image2 …… test 类别1 image1 image2 …… 类别2 image1 image2 …… val(可选) 类别1 image1 image2 …… 类别2 image1 image2 ……
import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms import torchvision.datasets as datasets from wideresnet import WideResNet BATCH_SIZE = 4 transform_train = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomVerticalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理 # 需要更多数据预处理,自己查 ]) transform_test = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理 # 需要更多数据预处理,自己查 ]) #读取数据 dataset_train = datasets.ImageFolder('./train', transform_train) dataset_test = datasets.ImageFolder('./test', transform) #dataset_val = datasets.ImageFolder('data/val', transform) # 上面这一段是加载测试集的 train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True) # 训练集 test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True) # 测试集 #val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) # 验证集 # 对应文件夹的label print(dataset_train.class_to_idx) # 这是一个字典,可以查看每个标签对应的文件夹,也就是你的类别。 # 训练好模型后输入一张图片测试,比如输出是99,就可以用字典查询找到你的类别名称 print(dataset_test.class_to_idx) #print(dataset_val.class_to_idx) if __name__ == '__main__': model = WideResNet(40, 100, 4, 0.0) for batch_idx, (images, labels) in enumerate(train_loader): # compute output outputs = model(images) print(data.shape) print(target)