【机器学习】使用pyplot绘制MNIST数据集中的手写数字

MNIST数据集是人工智能大佬Yann LeCun给出的一套手写数字的数据集,训练集包含60,000个样本和标注,测试集包含10,000个样本和标注。可以给新手用来练手用。

数据集表示

  1. 标注:数字分为0-9,总共10个数字,标注也是从0-9,分别对应0-910个数字;
  2. 图片:将每张图片切分成2828的矩阵,矩阵的每个元素使用灰度值来表示,所以总共使用一个2828的矩阵来表示图片;

下载数据集

数据集下载地址:http://yann.lecun.com/exdb/mnist/
整个数据集分为四个部分:

  • train-images-idx3-ubyte.gz: 训练集图片 (9912422 bytes)
  • train-labels-idx1-ubyte.gz: 训练集标注 (28881 bytes)
  • t10k-images-idx3-ubyte.gz: 测试集图片 (1648877 bytes)
  • t10k-labels-idx1-ubyte.gz: 测试集标注 (4542 bytes)

解析数据集文件

在LeCun的网站上,给出了数据集的格式,需要关注的点有:

  1. 存储二进制数据;
  2. 使用大端法存储;
  3. 标注集(包括训练标注集和测试标注集):第一个字节为魔数(Magic Number),第二个字节为标注总个数(训练集-60,000,测试集10,000),后续每个字节为对应标注数值;
  4. 图片集(包括训练图片集和测试图片集):第一个字节为魔数(Magic Number),第二个字节为图片总个数(训练集-60,000,测试集10,000),第四个字节为每张图片表示矩阵的rows,第五个字节为每张图片表示矩阵的cols

绘制图片

使用pyplot来绘制0-9这10个数字总的来说有以下几个步骤:

  1. 加载图片和标注数据;
  2. 使用subplots方法创建一张2*5的画布;
  3. 将画布展开,并把0-9是个数字使用imshow方法绘制进去;
  4. plt.show();

具体代码

  1. 加载数据(load_data.py)
import os
import struct
import numpy as np


def load_mnist(path, kind='train'):
    """拼接路径"""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
    	"""使用大端法读取2个字节,第一个是魔数,第二个是个数"""
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        """依次读取标注值"""
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        """使用大端法读取4个字节,第一个是魔数,第二个是个数,三四分别是rows、cols"""
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        """依次读取值,并reshape为length*784的矩阵"""
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

  1. 绘制图片(main.py)
import load_data
import matplotlib.pyplot as plt

"""加载数据"""
images, labels = load_data.load_mnist('/Users/wowo/Documents/0-Tensorflow')

"""创建画布"""
fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True,
)

"""平铺画布"""
ax = ax.flatten()
for i in range(10):
    """获取数据集中第一次出现的0-9数字,并reshape到28*28的矩阵"""
    img = images[labels == i][0].reshape(28, 28)
    """绘制数字"""
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

"""隐藏横纵坐标"""
ax[0].set_xticks([])
ax[0].set_yticks([])
"""美化画布,使之更紧凑"""
plt.tight_layout()
"""绘制画布"""
plt.show()

参开文献

  1. http://yann.lecun.com/exdb/mnist/
  2. https://www.cnblogs.com/xianhan/p/9145966.html
上一篇:python数据可视化 | matplotlib.pyplot()函数绘制线形图,感受数据直观变化


下一篇:QML 地图修改插件源码(三),Map在Plugin中设置加载地图类型