PyTorch - fashion-MNIST数据集的使用

FashionMNIST数据集

Fashion-MNIST是一个10类服饰分类数据集, 我们可以使用它来检验不同算法的表现, 这是MNIST数据集不能做到的(原因在这里,想了解的可以看看介绍)。

torchvision的结构

torchvision包包含了很多图像相关的数据集以及处理方法, 并且有常用的模型结构。

  • torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  • torchvision.datasets: 一些加载数据的函数及常用的数据集接口;

  • torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;

  • torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;

  • torchvision.utils: 其他的一些有用的方法。

# 导入需要的包
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
import matplotlib.pyplot as plt

加载数据

设置数据的缓存目录为 root_dir

随后获得训练集和测试集数据,第一次运行的时候会下载 FashionMNIST 数据集到指定的目录下

下载速度慢解决方案: Gitee 极速下载 Fashion-MNIST

将Fashion-MNIST/ data / fashion的四个压缩文件解压到指定的目录,不要删除原来的压缩包文件,因此数据集总共有八个文件。

# 通过标签得到描述语句
def get_f_mnist_labels(labels):
    """

    :param labels: 图片对应的标签(0-9的数字)
    :return: 标签对应的描述
    """
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def show_fashion_mnist(images, labels):
    """

    :param images: 读取的图片
    :param labels: 图片对应的标签
    :return: None, 输出图片,并且在图片上方对应标签给出描述
    """
    _, figs = plt.subplots(1, len(images), figsize=(12, 2))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)))
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

root_dir = "./torchvision/data/"
f_mnist_train = FashionMNIST(root=root_dir, train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = FashionMNIST(root=root_dir, train=False, download=True, transform=transforms.ToTensor())

print("f_mnist_train length:", len(f_mnist_train), end='\n')
print("f_mnist_test length:", len(f_mnist_test), end='\n')

x, y = [], []
for i in range(10):
    x.append(f_mnist_train[i][0])
    y.append(f_mnist_train[i][1])
show_fashion_mnist(x, get_f_mnist_labels(y))
f_mnist_train length: 60000
f_mnist_test length: 10000

PyTorch - fashion-MNIST数据集的使用

读取小批量数据

from torch.utils.data import DataLoader

batch_size = 256
train_iter = DataLoader(f_mnist_train, batch_size, shuffle=True, num_workers = 0)

# 计算加载数据的时间
import time
start = time.time()
for X, y in train_iter:
    continue
print("read train data cost %.4f seconds" % (time.time()-start))

read train data cost 4.9213 seconds

注意

本章的介绍思路来自 Apple Store的 “Python AI” app, 作为学习目的使用, 以及在此文章中记录学习过程(如有侵权,请联系作者删除。)

PyTorch - fashion-MNIST数据集的使用

上一篇:7 Fashion 数据识别


下一篇:Tensorflow 变分自编码器:Fashion MNIST图片的重建与生成