mxnet 图像分类模型训练教程

mxnet 图像识别教程

代码:
https://github.com/dwSun/classification-tutorial.git

这里以 TinyMind 《汉字书法识别》比赛数据为例,展示使用 mxnet 进行图像数据分类模型训练的整个流程。

数据地址请参考:
https://www.tinymind.cn/competitions/41#property_23

或到这里下载:
*练习赛数据下载地址:
训练集:链接: https://pan.baidu.com/s/1UxvN7nVpa0cuY1A-0B8gjg 密码: aujd

测试集: https://pan.baidu.com/s/1tzMYlrNY4XeMadipLCPzTw 密码: 4y9k

数据探索

请参考官方的数据说明

数据处理

竞赛中只有训练集 train 数据有准确的标签,因此这里只使用 train 数据即可,实际应用中,阶段 1、2 的榜单都需要使用。

数据下载

下载数据之后进行解压,得到 train 文件夹,里面有 100 个文件夹,每个文件夹名字即是各个汉字的标签。类似的数据集结构经常在分类任务中见到。可以使用下述命令验证一下每个文件夹下面文件的数量,看数据集是否符合竞赛数据描述:

for l in $(ls); do echo $l $(ls $l|wc -l); done

划分数据集

因为这里只使用了 train 集,因此我们需要对已有数据集进行划分,供模型训练的时候做验证使用,也就是 validation 集的构建。

一般认为,train 用来训练模型,validation 用来对模型进行验证以及超参数( hyper parameter)调整,test 用来做模型的最终验证,我们所谓模型的性能,一般也是指 test 集上模型的性能指标。但是实际项目中,一般只有 train 集,同时没有可靠的 test 集来验证模型,因此一般将 train 集划分出一部分作为 validation,同时将 validation 上的模型性能作为最终模型性能指标。

一般情况下,我们不严格区分 validation 和 test。

这里将每个文件夹下面随机50个文件拿出来做 validation。

export train=train
export val=validation

for d in $(ls $train); do
    mkdir -p $val/$d/
    for f in $(ls train/$d | shuf | head -n 50 ); do
        mv $train/$d/$f $val/$d/;
    done;
done
需要注意,这里的 validation 只间接通过超参数的调整参与了模型训练。因此有一定的数据浪费。

模型训练代码-数据部分

首先导入 mxnet 看一下版本

import mxnet as mx

mx.__version__
'1.6.0'


训练模型的时候,模型内部全部都是数字,没有任何可读性,而且这些数字也需要人为给予一些实际的意义,这里将 100 个汉字作为模型输出数字的文字表述。

需要注意的是,因为模型训练往往是一个循环往复的过程,因此一个稳定的文字标签是很有必要的,这里利用相关 python 代码在首次运行的时候生成了一个标签文件,后续检测到这个标签文件,则直接调用即可。

import os

if os.path.exists("labels.txt"):
    with open("labels.txt") as inf:
        classes = [l.strip() for l in inf]
else:
    classes = os.listdir("worddata/train/")
    with open("labels.txt", "w") as of:
        of.write("\r\n".join(classes))

class_idx = {v: k for k, v in enumerate(classes)}
idx_class = dict(enumerate(classes))

pyTorch里面,classes有自己的组织方式,这里我们想要自定义,要做一下转换。

from PIL import Image

pth_classes = classes[:]
pth_classes.sort()
pth_classes_to_idx = {v: k for k, v in enumerate(pth_classes)}


def transform(data, pth_idx):
    return data, class_idx[pth_classes[pth_idx]]

mxnet 中提供了直接从目录中读取数据并进行训练的 API 这里使用的API如下。

这里使用了两个数据集,分别代表 train、validation。

需要注意的是,由于 数据中,使用的图像数据集,其数值在(0, 255)之间。同时,mxnet 用 opencv 来处理图像的加载,其图像的数据 layout 是(H,W,C),而 mxnet 用来训练的数据需要是(C,H,W)的,因此需要对数据做一些转换。另外,train 数据集做了一定的数据预处理(旋转、明暗度),用于进行数据增广,也做了数据打乱(shuffle),而 validation则不需要做类似的变换。

ToTensor这个操作会转换数据的 layout,因此要放在最后面。
from multiprocessing import cpu_count

transform_train = mx.gluon.data.vision.transforms.Compose(
    [
        # mx.gluon.data.vision.transforms.RandomRotation((-15, 15), zoom_out=True),
        # 带随机旋转的版本还没发布
        mx.gluon.data.vision.transforms.Resize((128, 128)),
        mx.gluon.data.vision.transforms.RandomColorJitter(brightness=0.5),
        mx.gluon.data.vision.transforms.ToTensor(),
    ]
)
transform_val = mx.gluon.data.vision.transforms.Compose(
    [
        mx.gluon.data.vision.transforms.Resize((128, 128)),
        mx.gluon.data.vision.transforms.ToTensor(),
    ]
)


img_gen_train = mx.gluon.data.vision.datasets.ImageFolderDataset(
    "worddata/train/", transform=transform, flag=0
)


img_gen_val = mx.gluon.data.vision.datasets.ImageFolderDataset(
    "worddata/validation/", transform=transform, flag=0
)

batch_size = 32

img_train = mx.gluon.data.DataLoader(
    img_gen_train.transform_first(transform_train),
    batch_size=batch_size,
    shuffle=True,
    num_workers=cpu_count(),
)
img_val = mx.gluon.data.DataLoader(
    img_gen_val.transform_first(transform_val),
    batch_size=batch_size,
    num_workers=cpu_count(),
)

到这里,这两个数据集就可以使用了,正式模型训练之前,我们可以先来看看这个数据集是怎么读取数据的,读取出来的数据又是设么样子的。

for imgs, labels in img_train:
    # img_train 只部分满足 generator 的语法,不能用 next 来获取数据
    break
imgs.shape, labels.shape
((32, 1, 128, 128), (32,))


可以看到数据是(batch, channel, height, width, height), 因为这里是灰度图像,因此 channel 是 1。

需要注意,pyTorch、mxnet使用的数据 layout 与Tensorflow 不同,因此数据也有一些不同的处理方式。

把图片打印出来看看,看看数据和标签之间是否匹配
mxnet 图像分类模型训练教程

import numpy as np
from matplotlib import pyplot as plt

plt.imshow(imgs.asnumpy()[0, 0, :, :], cmap="gray")
classes[labels.asnumpy()[0]]
'利'




模型训练代码-模型构建

mxnet 中使用动态图来构建模型,模型构建比较简单。这里演示的是使用 class 的方式构建模型,对于简单模型,还可以直接使用 Sequential 进行构建。

这里的复杂模型也是用 Sequential 的简单模型进行的叠加。

这里构建的是VGG模型,关于VGG模型的更多细节请参考 1409.1556。
class MyModel(mx.gluon.nn.HybridBlock):
    def __init__(self):
        super(MyModel, self).__init__()
        # 模型有两个主要部分,特征提取层和分类器

        # 这里是特征提取层
        self.feature = mx.gluon.nn.HybridSequential()

        self.feature.add(self.conv(64))
        self.feature.add(self.conv(64, add_pooling=True))

        self.feature.add(self.conv(128))
        self.feature.add(self.conv(128, add_pooling=True))

        self.feature.add(self.conv(256))
        self.feature.add(self.conv(256))
        self.feature.add(self.conv(256, add_pooling=True))
        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512, add_pooling=True))

        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512, add_pooling=True))
        self.feature.add(mx.gluon.nn.GlobalAvgPool2D())
        self.feature.add(mx.gluon.nn.Flatten())

        self.feature.add(mx.gluon.nn.Dense(4096, activation="relu"))
        self.feature.add(mx.gluon.nn.BatchNorm())

        self.feature.add(mx.gluon.nn.Dense(4096, activation="relu"))
        self.feature.add(mx.gluon.nn.BatchNorm())

        self.feature.add(mx.gluon.nn.Dropout(0.5))
        # 这个简单的机构是分类器
        self.pred = mx.gluon.nn.Dense(100)

    def conv(self, filters, add_pooling=False):
        # 模型大量使用重复模块构建,
        # 这里将重复模块提取出来,简化模型构建过程
        model = mx.gluon.nn.HybridSequential()
        model.add(mx.gluon.nn.Conv2D(filters, 3, padding=1, activation="relu"))
        model.add(mx.gluon.nn.BatchNorm())

        if add_pooling:
            model.add(mx.gluon.nn.MaxPool2D(strides=2))
        return model

    def hybrid_forward(self, F, x):
        # call 用来定义模型各个结构之间的运算关系

        x = self.feature(x)
        return self.pred(x)

上面使用的是 gluon api 相比于 gluon api,mxnet 还有纯符号和纯静态图的方式,但是不如 gluon api 方便,就像 keras 之于 Tensorflow, gluon api 也是 mxnet 社区主推的方式。

HybridBlock 这种带 Hybrid 前缀的模块,底层可以编译成静态图,速度快一些,可以尽量多用一下。不带Hybrid 前缀的模块使用起来不太一样。

实例化一个模型看看:

ctx = mx.gpu(0)
model = MyModel()
model.initialize(ctx=ctx, init=mx.initializer.Xavier())
model.hybridize()

model
MyModel(
  (feature): HybridSequential(
    (0): HybridSequential(
      (0): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (1): HybridSequential(
      (0): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    )
    (2): HybridSequential(
      (0): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (3): HybridSequential(
      (0): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    )
    (4): HybridSequential(
      (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (5): HybridSequential(
      (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (6): HybridSequential(
      (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    )
    (7): HybridSequential(
      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (8): HybridSequential(
      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (9): HybridSequential(
      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    )
    (10): HybridSequential(
      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (11): HybridSequential(
      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    )
    (12): HybridSequential(
      (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), Activation(relu))
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
      (2): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False, global_pool=False, pool_type=max, layout=NCHW)
    )
    (13): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True, global_pool=True, pool_type=avg, layout=NCHW)
    (14): Flatten
    (15): Dense(None -> 4096, Activation(relu))
    (16): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (17): Dense(None -> 4096, Activation(relu))
    (18): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (19): Dropout(p = 0.5, axes=())
  )
  (pred): Dense(None -> 100, linear)
)


模型训练代码-训练相关部分

要训练模型,我们还需要定义损失,优化器等。

loss_object = mx.gluon.loss.SoftmaxCrossEntropyLoss()
optimizer = mx.gluon.Trainer(model.collect_params(), mx.optimizer.Adam())  # 优化器有些参数可以设置

train_accuracy = mx.metric.Accuracy()
val_accuracy = mx.metric.Accuracy()
import time  # 模型训练的过程中手动追踪一下模型的训练速度

因为模型整个训练过程一般是一个循环往复的过程,所以经常性的保存重启模型训练中间过程是有必要的。
这里我们一个ckpt保存了两份,便于中断模型的重新训练。

import os

if os.path.exists("model.params"):
    # 检查 checkpoint 是否存在
    # 如果存在,则加载 checkpoint
    model.load_parameters("model.params")

    # 这里是一个比较生硬的方式,其实还可以观察之前训练的过程,
    # 手动选择准确率最高的某次 checkpoint 进行加载。
    print("model lodaded")
EPOCHS = 20
for epoch in range(EPOCHS):

    train_loss = 0
    train_samples = 0
    train_accuracy.reset()
    val_accuracy.reset()

    val_loss = 0
    val_samples = 0

    start = time.time()
    for imgs, labels in img_train:
        imgs = imgs.as_in_context(ctx)
        labels = labels.as_in_context(ctx)

        with mx.autograd.record():
            preds = model(imgs)
            loss = loss_object(preds, labels)
        loss.backward()

        optimizer.step(batch_size)

        train_loss += loss.sum().asscalar()
        train_accuracy.update(labels, preds)

        train_samples += imgs.shape[0]
        mx.nd.waitall()

    train_samples_per_second = train_samples / (time.time() - start)

    start = time.time()
    for imgs, labels in img_val:
        imgs = imgs.as_in_context(ctx)
        labels = labels.as_in_context(ctx)

        preds = model(imgs)
        loss = loss_object(preds, labels)

        val_loss += loss.sum().asscalar()
        val_accuracy.update(labels, preds)

        val_samples += imgs.shape[0]
        mx.nd.waitall()

    val_samples_per_second = val_samples / (time.time() - start)

    print(
        "Epoch {} Loss {}, Acc {}, Val Loss {}, Val Acc {}".format(
            epoch,
            train_loss / train_samples,
            train_accuracy.get()[1] * 100,
            val_loss / val_samples,
            val_accuracy.get()[1] * 100,
        )
    )
    print(
        "Speed train {}imgs/s val {}imgs/s".format(
            train_samples_per_second, val_samples_per_second
        )
    )

    model.save_parameters("model.params")
    model.save_parameters("model-{:04d}.params".format(epoch))

    # 每个 epoch 保存一下模型,需要注意每次
    # 保存要用一个不同的名字,不然会导致覆盖,
    # 同时还要关注一下磁盘空间占用,防止太多
    # chekcpoint 占满磁盘空间导致错误。
Epoch 0 Loss 6.9573759181431365, Acc 1.042857142857143, Val Loss 15.420175067138672, Val Acc 1.68
Speed train 334.933801136163imgs/s val 791.7329888468499imgs/s
Epoch 1 Loss 6.7644785858154295, Acc 1.1114285714285714, Val Loss 21.39283991394043, Val Acc 0.9400000000000001
Speed train 344.24428987859835imgs/s val 1005.8295516339732imgs/s
Epoch 2 Loss 6.427109525844029, Acc 1.1142857142857143, Val Loss 1441.4082737342835, Val Acc 1.46
Speed train 350.2307806800212imgs/s val 1025.2414031403846imgs/s
Epoch 3 Loss 5.979110830688477, Acc 1.1114285714285714, Val Loss 402112.5587174507, Val Acc 1.08
Speed train 343.90973606373484imgs/s val 1007.538116122915imgs/s
Epoch 4 Loss 5.7047360944475445, Acc 1.0114285714285716, Val Loss 28.884018599700926, Val Acc 1.16
Speed train 346.9289205956802imgs/s val 1042.5033657692259imgs/s
Epoch 5 Loss 5.429534565952846, Acc 1.1828571428571428, Val Loss 25.10572717514038, Val Acc 1.06
Speed train 346.55915907021785imgs/s val 1042.9843522905003imgs/s
Epoch 6 Loss 5.1636578002929685, Acc 1.18, Val Loss 292.2765090805054, Val Acc 1.26
Speed train 346.677920645666imgs/s val 1041.1619895883405imgs/s
Epoch 7 Loss 4.944499089268276, Acc 1.4000000000000001, Val Loss 8.26984390258789, Val Acc 1.66
Speed train 345.7027344018046imgs/s val 1041.9690306718644imgs/s
Epoch 8 Loss 4.776889256722587, Acc 1.4857142857142858, Val Loss 5.135510383605957, Val Acc 1.9
Speed train 347.51812408602746imgs/s val 1055.2600700466792imgs/s
Epoch 9 Loss 4.627297317940848, Acc 2.2885714285714287, Val Loss 19.935467903900147, Val Acc 3.36
Speed train 345.4781099199051imgs/s val 1012.3481105170373imgs/s
Epoch 10 Loss 4.379766181509836, Acc 3.842857142857143, Val Loss 10.535897972106934, Val Acc 6.4399999999999995
Speed train 346.03842686704553imgs/s val 1046.8390845908468imgs/s
Epoch 11 Loss 3.685629348100935, Acc 11.297142857142857, Val Loss 52.809229167175296, Val Acc 19.12
Speed train 348.1024735895058imgs/s val 1005.0561635670016imgs/s
Epoch 12 Loss 2.9683764310564316, Acc 24.19142857142857, Val Loss 96.89430379104614, Val Acc 34.64
Speed train 347.82284815108676imgs/s val 1015.7570915242692imgs/s
Epoch 13 Loss 2.32016633845738, Acc 39.92285714285714, Val Loss 1393.9259933712005, Val Acc 48.64
Speed train 347.1405956344324imgs/s val 1002.6255219358517imgs/s
Epoch 14 Loss 1.7546388859340123, Acc 53.98285714285714, Val Loss 405771.5952173068, Val Acc 62.68
Speed train 347.7790321353627imgs/s val 1016.1043531852587imgs/s
Epoch 15 Loss 1.4703079643249513, Acc 61.34285714285714, Val Loss 55974.54370078631, Val Acc 67.4
Speed train 341.66829982471imgs/s val 996.3958929419091imgs/s
Epoch 16 Loss 1.0772937241145544, Acc 71.44857142857143, Val Loss 175934832.46710944, Val Acc 74.83999999999999
Speed train 333.6240690789895imgs/s val 988.0520794326872imgs/s
Epoch 17 Loss 0.9253648305075509, Acc 75.28285714285714, Val Loss 24513272.186748803, Val Acc 73.38
Speed train 324.7226062593469imgs/s val 964.7906383539004imgs/s
Epoch 18 Loss 0.817839416544778, Acc 77.78, Val Loss 21546954.254889924, Val Acc 78.3
Speed train 320.69584017889264imgs/s val 945.9076133097823imgs/s
Epoch 19 Loss 0.869460376398904, Acc 76.63142857142857, Val Loss 8773747185.209038, Val Acc 78.72
Speed train 314.57207850646853imgs/s val 936.1137303334394imgs/s

一些技巧

因为这里定义的模型比较大,同时训练的数据也比较多,每个 epoch 用时较长,因此,如果代码有 bug 的话,经过一次 epoch 再去 debug 效率比较低。

这种情况下,我们使用的数据生成过程又是自己手动指定数据数量的,因此可以尝试缩减模型规模,定义小一些的数据集来快速验证代码。在这个例子里,我们可以通过注释模型中的卷积和全连接层的代码来缩减模型尺寸,通过修改训练循环里面的数据数量来缩减数据数量。

训练的速度很慢

一开始训练速度比较慢,因为 mxnet 默认初始化方式是 uniform。

这里改成了跟 tf 一样的 xavier (tf里面叫做 glorot_uniform)之后训练速度还是比较慢,要 10 个 epoch才能看到收敛。而且 TF 里面 20epochs 能达到 90% 的准确率,这里才 76%,应该是哪里有什么问题,我再看看怎么解决。


下面内容是另外一个文件,因此部分代码重复

模型的使用代码

模型训练好了之后要实际应用。对于模型部署有很多成熟的方案,如 Nvidia 的 TensorRT, Intel 的 OpenVINO 等,都可以做模型的高效部署,这里限于篇幅不涉及相关内容。

在模型训练过程中,也可以使用使用框架提供的 API 做模型的简单部署以方便开发。

import os

import mxnet as mx

mx.__version__
'1.6.0'


首先要加载模型的标签用于展示,因为我们训练的时候就已经生成了标签文件,这里直接用写好的代码就可以。

if os.path.exists("labels.txt"):
    with open("labels.txt") as inf:
        classes = [l.strip() for l in inf]
else:
    classes = os.listdir("worddata/train/")
    with open("labels.txt", "w") as of:
        of.write("\r\n".join(classes))

接着是模型的定义,这里直接将训练中使用的模型代码拿来即可。

class MyModel(mx.gluon.nn.HybridBlock):
    def __init__(self):
        super(MyModel, self).__init__()
        # 模型有两个主要部分,特征提取层和分类器

        # 这里是特征提取层
        self.feature = mx.gluon.nn.HybridSequential()

        self.feature.add(self.conv(64))
        self.feature.add(self.conv(64, add_pooling=True))

        self.feature.add(self.conv(128))
        self.feature.add(self.conv(128, add_pooling=True))

        self.feature.add(self.conv(256))
        self.feature.add(self.conv(256))
        self.feature.add(self.conv(256, add_pooling=True))
        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512, add_pooling=True))

        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512))
        self.feature.add(self.conv(512, add_pooling=True))
        self.feature.add(mx.gluon.nn.GlobalAvgPool2D())
        self.feature.add(mx.gluon.nn.Flatten())

        self.feature.add(mx.gluon.nn.Dense(4096, activation="relu"))
        self.feature.add(mx.gluon.nn.BatchNorm())

        self.feature.add(mx.gluon.nn.Dense(4096, activation="relu"))
        self.feature.add(mx.gluon.nn.BatchNorm())

        self.feature.add(mx.gluon.nn.Dropout(0.5))
        # 这个简单的机构是分类器
        self.pred = mx.gluon.nn.Dense(100)

    def conv(self, filters, add_pooling=False):
        # 模型大量使用重复模块构建,
        # 这里将重复模块提取出来,简化模型构建过程
        model = mx.gluon.nn.HybridSequential()
        model.add(mx.gluon.nn.Conv2D(filters, 3, padding=1, activation="relu"))
        model.add(mx.gluon.nn.BatchNorm())

        if add_pooling:
            model.add(mx.gluon.nn.MaxPool2D(strides=2))
        return model

    def hybrid_forward(self, F, x):
        # call 用来定义模型各个结构之间的运算关系

        x = self.feature(x)
        return self.pred(x)

有了模型的定义之后,我们可以加载训练好的模型,跟模型训练的时候类似,我们可以直接加载模型训练中的 checkpoint。

ctx = mx.gpu(1)
model = MyModel()
model.initialize()

model.collect_params().reset_ctx(ctx)  # 丢到GPU上运行
model.hybridize()

model

import os

if os.path.exists("model.params"):
    # 检查 checkpoint 是否存在
    # 如果存在,则加载 checkpoint
    model.load_parameters("model.params")

    # 这里是一个比较生硬的方式,其实还可以观察之前训练的过程,
    # 手动选择准确率最高的某次 checkpoint 进行加载。
    print("model lodaded")
model lodaded

对于数据,我们需要直接处理图片,因此这里导入一些图片处理的库和数据处理的库

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

直接打开某个图片

img = Image.open("worddata/validation/从/116e891836204e4e67659d2b73a7e4780a37c301.jpg")

plt.imshow(img, cmap="gray")

mxnet 图像分类模型训练教程

需要注意,模型在训练的时候,我们对数据进行了一些处理,在模型使用的时候,我们要对数据做一样的处理,如果不做的话,模型最终的结果会出现不可预料的问题。

img = img.resize((128, 128))
img = np.array(img) / 255
img.shape
(128, 128)


模型对图片数据的运算其实很简单,一行代码就可以。

这里需要注意模型处理的数据是 4 维的,而上面的图片数据实际是 2 维的,因此要对数据进行维度的扩充。同时模型的输出是 2 维的,带 batch ,所以需要压缩一下维度。
pred = np.squeeze(
    mx.nd.softmax(
        model(mx.nd.array(img[np.newaxis, np.newaxis, :, :]).as_in_context(ctx))
    ).asnumpy()
)
pred_idx = pred.argsort()[-5:]

print([pred[idx] for idx in pred_idx])
print([classes[idx] for idx in pred_idx])
[8.337523e-05, 0.0002217273, 0.0006491087, 0.0008851457, 0.99768066]
['作', '流', '遂', '夜', '从']

这里只给出了 top5 的结果,可以看到,准确率还是不错的。

上一篇:visual studio 2022新建第一个c++项目编译失败,提示E696,无法打开ctype.h errno.h float.h stdio.h


下一篇:《JavaScript高级程序设计》读书笔记(九):本地对象Array