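Image Classification Dataset: FashionMNIST
① Introduction: The FashionMNIST dataset contains 10 categories of clothing items:
['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boots']
Some sample items:
② Details: The dataset contains 70,000 images. Each image has shape [1 channel, height 28, width 28] and corresponds to exactly one clothing item (one label). The training set and the test set are stored separately, with 60,000 and 10,000 images respectively.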
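As a quick sanity check on these figures, the following sketch does the arithmetic in plain Python. The variable names are illustrative, and the one-byte-per-value storage is an assumption (8-bit grayscale), not something stated above.

```python
# Dataset sizes as stated in the text.
num_train = 60_000
num_test = 10_000
channels, height, width = 1, 28, 28

total_images = num_train + num_test
values_per_image = channels * height * width

# Assumption: raw pixels are stored as 8-bit grayscale (1 byte per value).
raw_bytes = total_images * values_per_image

print(total_images)      # 70000
print(values_per_image)  # 784
print(raw_bytes / 1e6)   # 54.88 (megabytes, uncompressed)
```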
③ Exploring the FashionMNIST dataset
Import the required libraries and download the dataset
%matplotlib inline
import torch
from IPython import display
import torchvision  # utilities for image handling and computer vision
from torch.utils import data  # data-loading utilities
from torchvision import transforms  # image preprocessing transforms for PyTorch
from d2l import torch as d2l
import matplotlib.pyplot as plt
import time

def use_svg_display():
    """Render figures as SVG vector graphics."""
    # Note: this call is deprecated in newer IPython releases, where
    # matplotlib_inline.backend_inline.set_matplotlib_formats('svg') replaces it.
    display.set_matplotlib_formats('svg')

use_svg_display()  # SVG output keeps the figures sharp
# Download the dataset
# ToTensor converts an (H, W, C) image into a (C, H, W) float tensor and divides
# every element by 255, so all values lie between 0 and 1
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)
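To make the effect of that transform concrete, here is a minimal pure-Python sketch of the two steps described in the comment (axis reordering and division by 255). It only mimics the described behaviour on nested lists; it is not torchvision's implementation, and the helper name `to_chw_unit_range` and the pixel values are made up for illustration.

```python
def to_chw_unit_range(image_hwc):
    """Convert an H x W x C nested list of 0..255 ints to C x H x W floats in [0, 1]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[r][c][ch] / 255 for c in range(width)] for r in range(height)]
        for ch in range(channels)
    ]

# A tiny 2 x 3 single-channel "image" (hypothetical pixel values).
image = [
    [[0], [51], [102]],
    [[153], [204], [255]],
]

converted = to_chw_unit_range(image)
print(len(converted), len(converted[0]), len(converted[0][0]))  # 1 2 3
print(converted[0][0])  # [0.0, 0.2, 0.4]
print(converted[0][1])  # [0.6, 0.8, 1.0]
```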
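Inspecting the dataset
len(mnist_train), len(mnist_test)
# Output: (60000, 10000), i.e. 60,000 training images and 10,000 test images
mnist_train[0][0].shape
# torch.Size([1, 28, 28]): channel count and size of a single image
Visualizing the dataset; the result is the image shown in the introduction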
def get_fashion_mnist_labels(labels):
    """Return the text labels of the Fashion-MNIST dataset."""
    test_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boots']
    return [test_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  # not yet examined in detail
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())  # tensor image
        else:
            ax.imshow(img)  # any other image type
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
images = show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
images
plt.savefig('部分服饰.png', facecolor='white', edgecolor='red')  # save the generated figure
④ Importing the dataset
Wrap the loading logic in a function that reads the dataset into memory
def get_dataloader_workers():
    """Number of worker processes used to read the data (4 is an assumed value; tune it for your machine)."""
    return 4

def load_data_fashion_mnist(batch_size, resize=None):
    """Load the Fashion-MNIST dataset into memory."""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))  # resize each image to resize x resize
    trans = transforms.Compose(trans)  # chain several image transforms together
    mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans)
    mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
What the two parameters mean:
batch_size: how many images are read per batch
resize: whether to scale the images up or down to a new size, e.g. resize=66 makes each image 66 x 66
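One practical consequence of `batch_size` is the number of batches in a full pass over the data. The sketch below computes it for the dataset sizes given earlier, assuming the final, smaller batch is kept (which matches `DataLoader`'s default of `drop_last=False`). The helper `num_batches` is a hypothetical name used only here.

```python
import math

def num_batches(num_samples, batch_size):
    """Batches per pass when the final, smaller batch is kept."""
    return math.ceil(num_samples / batch_size)

print(num_batches(60_000, 8))    # 7500
print(num_batches(10_000, 8))    # 1250
print(num_batches(60_000, 256))  # 235
print(num_batches(10_000, 256))  # 40
```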
⑤ Loading the data
train_iter, test_iter = load_data_fashion_mnist(8, 12)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
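Output: torch.Size([8, 1, 12, 12]) torch.float32 torch.Size([8]) torch.int64
Explanation: each batch contains 8 images; every image is single-channel with size 12 x 12, and each image has its own label, giving 8 labels in total.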
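The same batch layout can be reproduced without downloading anything. The following sketch swaps the real dataset for random tensors wrapped in a `TensorDataset`; the sample count of 32 and the random contents are assumptions made purely for illustration.

```python
import torch
from torch.utils import data

# Synthetic stand-in for the resized dataset: 32 random "images" and labels.
images = torch.rand(32, 1, 12, 12)    # float32 values in [0, 1)
labels = torch.randint(0, 10, (32,))  # int64 class indices 0..9
dataset = data.TensorDataset(images, labels)

loader = data.DataLoader(dataset, batch_size=8, shuffle=True)
X, y = next(iter(loader))
print(X.shape, X.dtype, y.shape, y.dtype)
# torch.Size([8, 1, 12, 12]) torch.float32 torch.Size([8]) torch.int64
```

The leading dimension of `X` and the length of `y` both equal `batch_size`, which is how images and labels stay paired within a batch.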
⑥ Viewing a single image
for X, y in test_iter:
    print(X[0].tolist(), y[0])
    break
Output:
[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.019607843831181526, 0.1411764770746231, 0.019607843831181526, 0.003921568859368563, 0.10196078568696976, 0.062745101749897], [0.0, 0.0, 0.0, 0.0, 0.0, 0.007843137718737125, 0.18431372940540314, 0.4745098054409027, 0.4745098054409027, 0.43921568989753723, 0.47058823704719543, 0.11372549086809158], [0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563, 0.125490203499794, 0.38823530077934265, 0.5333333611488342, 0.6039215922355652, 0.6352941393852234, 0.5803921818733215, 0.1921568661928177], [0.0, 0.003921568859368563, 0.003921568859368563, 0.03529411926865578, 0.14901961386203766, 0.3803921639919281, 0.4588235318660736, 0.5607843399047852, 0.5921568870544434, 0.6117647290229797, 0.5921568870544434, 0.3843137323856354], [0.08235294371843338, 0.1921568661928177, 0.26274511218070984, 0.3607843220233917, 0.4431372582912445, 0.4745098054409027, 0.5254902243614197, 0.5764706134796143, 0.6078431606292725, 0.6078431606292725, 0.6196078658103943, 0.5176470875740051], [0.33725491166114807, 0.47058823704719543, 0.5058823823928833, 0.49803921580314636, 0.5137255191802979, 0.5647059082984924, 0.6078431606292725, 0.6392157077789307, 0.6941176652908325, 0.800000011920929, 0.7686274647712708, 0.5333333611488342], [0.0470588244497776, 0.12156862765550613, 0.24313725531101227, 0.30588236451148987, 0.32156863808631897, 0.3176470696926117, 0.2235294133424759, 0.11764705926179886, 0.20392157137393951, 0.35686275362968445, 0.3176470696926117, 0.20000000298023224], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
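The corresponding label is tensor(9). Labels are zero-based indices into the category list, so 9 refers to its last entry, 'ankle boots' (the 10th category).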
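A small sketch of how this output can be read back: the label is looked up in the category list, and each float is mapped to the 0..255 integer it most plausibly came from, given that the values were divided by 255. The helper names `label_to_text` and `to_pixel` are hypothetical, and the three sample values are copied from the printed output.

```python
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boots']

def label_to_text(label):
    """Map a zero-based class index to its text label."""
    return text_labels[int(label)]

def to_pixel(value):
    """Recover the 0..255 integer that a scaled value most likely came from."""
    return round(value * 255)

print(label_to_text(9))                # ankle boots
print(to_pixel(0.003921568859368563))  # 1
print(to_pixel(0.1411764770746231))    # 36
print(to_pixel(0.800000011920929))     # 204
```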
The end!!! ✨✨✨✨✨✨
Full code link: FashionMNIST dataset