Spatial Transformer Network with image classification

In this tutorial, you will learn how to augment your network using a visual attention mechanism called spatial transformer networks.

Spatial Transformer Network with image classification

Spatial transformer networks are a generalization of differentiable attention to any spatial transformation. Spatial transformer networks (STN for short) allow a neural network to learn how to perform spatial transformations on the input image in order to enhance the geometric invariance of the model. For example, it can crop a region of interest, scale and correct the orientation of an image. It can be a useful mechanism because CNNs are not invariant to rotation and scale and more general affine transformations.

解读: 上面提到的 crop a region of interest, scale and correct the orientation of an image. because CNNs are not invariant to rotation and scale and more general affine transformations. 这两句话,在我看来其实有一部分数据增强可以解决,所以两者在这块儿貌似有点小重合。

Spatial transformer networks boils down to three main components:

  • The localization network is a regular CNN which regresses the transformation parameters. The transformation is never learned explicitly from this dataset, instead the network learns automatically the spatial transformations that enhances the global accuracy.
  • The grid generator generates a grid of coordinates in the input image corresponding to each pixel from the output image.
  • The sampler uses the parameters of the transformation and applies it to the input image.

Spatial Transformer Network with image classification

解读: 此处说明了 STN 的构成主要分为 3 个部分,分别是 Localization network、Grid generator、Sampler。从上图也可以看出来,STN 的参数学习是一种端到端的方式。另外再提一点,如果你看过源码的话,可以发现 STN 能处理的输入分辨率是比较小的,如果分辨率比较大,那么也意味着我们需要扩张 STN 的网络,这个开销就有点大了,可能不太合适,不过可以考虑放在 ResNet 等网络的后半部分,不过还没实验是否能带来准确率提升。

下图是 MNIST 过后,展示的原始图片和转换后的图片。

Spatial Transformer Network with image classification

下图是 Cifar10 训练过后,展示的原始图片和转换后的图片。

Spatial Transformer Network with image classification

下面的图片是 Spatial Transformer Network 在 MNIST 数据集上的 test_acc = 99

Spatial Transformer Network with image classification

下面的图片是 Spatial Transformer Network 在 Cifar10 上的 test_acc = 65

Spatial Transformer Network with image classification

下面的图片是去掉 sth 之后,正常的网络训练结果,而且没有加数据增强,test_acc = 62

Spatial Transformer Network with image classification

下面的图片是去掉 sth 之后,添加了 AutoAugment 的训练结果,test_acc = 60

Spatial Transformer Network with image classification

这里让我有点意外的是在 Cifar10 上,添加了 AutoAugment 数据增强之后,竟然效果变差了,猜测和网络过于简单以及数据集图片的分辨率过小有关系,诶,还是那句话,有些时候某些手段到底有没有用,只有实验之后才知道。

下面的图片是在 Spatial Tranformer Network 的基础上,在卷积层部分添加了 BatchNorm2d,效果进一步提升,特别是这么简单的一个网络,竟然也有 test_acc = 72,不敢想,不敢想啊。

Spatial Transformer Network with image classification

代码地址: https://github.com/MaoXianXin/PycharmProjects

上一篇:SSRN:Spectral-Spatial residual network for HSI classification


下一篇:RabbitMQ学习系列二:.net 环境下 C#代码使用 RabbitMQ 消息队列