使用pytorch导入自建数据集

使用pytorch导入自建数据集
以mini_imagenet为例
其实是关键需要数据集的结构为

data
	train
		类别1
			image1
			image2
			……
		类别2
			image1
			image2
			……
	test
		类别1
			image1
			image2
			……
		类别2
			image1
			image2
			……
	val(可选)
		类别1
			image1
			image2
			……
		类别2
			image1
			image2
			……
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from wideresnet import WideResNet

BATCH_SIZE = 4
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理
 # 需要更多数据预处理,自己查
])
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理
 # 需要更多数据预处理,自己查
])

#读取数据
dataset_train = datasets.ImageFolder('./train', transform_train)
dataset_test = datasets.ImageFolder('./test', transform)
#dataset_val = datasets.ImageFolder('data/val', transform)

# 上面这一段是加载测试集的
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True) # 训练集
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True) # 测试集
#val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) # 验证集
# 对应文件夹的label
print(dataset_train.class_to_idx)   # 这是一个字典,可以查看每个标签对应的文件夹,也就是你的类别。
                                    # 训练好模型后输入一张图片测试,比如输出是99,就可以用字典查询找到你的类别名称
print(dataset_test.class_to_idx)
#print(dataset_val.class_to_idx)


if __name__ == '__main__':
    model = WideResNet(40, 100, 4, 0.0)
    for batch_idx, (images, labels) in enumerate(train_loader):
        # compute output
        outputs = model(images)
        print(data.shape)
        print(target)

 

上一篇:CSS3 边框彩虹跑马灯


下一篇:python opencv图片拼接源码