Fashion-MNIST是一个10类服饰分类数据。
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib as plt
import time
我们将使用torchcvision包,它服务于Pytorch深度学习框架的,由以下几部分构成
torchvisiion.datasets : 一些加载数据的函数和常用数据集接口
torchvision.models : 包含常用的模型结构(含预训练模型), 例如AlexNet,VGG等
torchvision.transforms : 常用的图片变换,例如裁剪,旋转等
torchvision.utils : 其他的一些有用方法。
首先通过 torchvision.datasets来下载数据集, 通过参数train来指定获取训练集还是测试集,
transform = transforms.ToTensor()使所有数据的每一个像素数组转换为(0.0, 1.0)的32位浮点数的Tensor。feature尺寸是( C, H, W), 第一维是通道数,
mnist_train 和 mnist_test 都是torch.utils.data.Dataset的子类, 可以使用len()方法获取数据集的大小。
mnist_train = torchvision.datasets.FashionMNIST(
root='D:~/Fashion_MNIST',
train = True,
download=True,
transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(
root='D:~/Fashion_MNIST',
train=False,
download=True,
transform=transforms.ToTensor())
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
如果出现下面的错误:
解决办法:
访问任意样本: 每一个数据都是有图片和标签,查看第0张图片形状和标签。
feature , label = mnist_train[0]
print(feature.shape, label)