Pytorch中FASHION-MNIST

Fashion-MNIST是一个10类服饰分类数据。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib as plt
import time

我们将使用torchcvision包,它服务于Pytorch深度学习框架的,由以下几部分构成

torchvisiion.datasets : 一些加载数据的函数和常用数据集接口

torchvision.models :   包含常用的模型结构(含预训练模型), 例如AlexNet,VGG等

torchvision.transforms : 常用的图片变换,例如裁剪,旋转等

torchvision.utils : 其他的一些有用方法。

首先通过 torchvision.datasets来下载数据集, 通过参数train来指定获取训练集还是测试集,

transform = transforms.ToTensor()使所有数据的每一个像素数组转换为(0.0, 1.0)的32位浮点数的Tensor。feature尺寸是( C, H, W), 第一维是通道数,

mnist_train 和 mnist_test 都是torch.utils.data.Dataset的子类, 可以使用len()方法获取数据集的大小。

 

mnist_train = torchvision.datasets.FashionMNIST(
    root='D:~/Fashion_MNIST',
    train = True,
    download=True,
    transform=transforms.ToTensor())

mnist_test = torchvision.datasets.FashionMNIST(
    root='D:~/Fashion_MNIST',
    train=False,
    download=True,
    transform=transforms.ToTensor())

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

如果出现下面的错误:

Pytorch中FASHION-MNIST

 解决办法:

https://blog.csdn.net/bkdly9/article/details/120634060?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522163868086616780357229552%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=163868086616780357229552&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-2-120634060.pc_search_result_cache&utm_term=serWarning%3A+The+given+NumPy+array+is+not+writeable&spm=1018.2226.3001.4187Pytorch中FASHION-MNISThttps://blog.csdn.net/bkdly9/article/details/120634060?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522163868086616780357229552%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=163868086616780357229552&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-2-120634060.pc_search_result_cache&utm_term=serWarning%3A+The+given+NumPy+array+is+not+writeable&spm=1018.2226.3001.4187

访问任意样本: 每一个数据都是有图片和标签,查看第0张图片形状和标签。

feature , label = mnist_train[0]

print(feature.shape, label)

上一篇:从零开始实现一个端到端的机器学习项目[3]


下一篇:Hadoop面试题(一)