复现 Mlp-Mixer(pytorch and keras)

文章目录


前言

mlp-mixer,Google又提出的一种基于感知机的网络。尽管CNN 已经在计算机视觉上取得很好的效果,最近提出来的基于Attention,以Vision Transformer 为首的神经网络已经 “杀疯” CV界,但是Google的大佬们认为CNN 和Attention 也不是必须的,于是就提出了mlp-mixer,在分类任务上也达到了很好的效果。但是网络的提出却早到了CNN之父LeCun的“教育”。因为网络的第一层(embedding时)却用到了卷积。
复现 Mlp-Mixer(pytorch and keras)
LeCun认为,这不过是 一个卷积核为1x1的卷积网络罢了。我在复现过程中,没有找到一个合适的数据集用该网络取得一个很好的效果。所以这里不附上训练的代码了。

项目 链接
论文 链接

博客中给出的就是网络结构的全部代码。

一、mlp-mixer原理介绍

1.网络结构

复现 Mlp-Mixer(pytorch and keras)
从上图我们可以看出,mlp-mixer 首先使用图片分成很多个小正方形的patch,每个patch的大小定义为patch_size。论文中实现这一步骤使用的是前面提到的卷积,卷积核的大小和步长均patch_size。至于卷积核的输出通道数,就是自己拍一拍脑袋就决定了,例如512 768 1024等。
网络不再使用传统的RELU激活函数,而是使用了GELU激活函数。
将图片分成小块后,在将它转换为一维结构。如图:
复现 Mlp-Mixer(pytorch and keras)
然后将每一个patch进行转换,如下图所示:
复现 Mlp-Mixer(pytorch and keras)
通过这样一种方式呢,就将一张图片转换为了一个大矩阵,就可以输入到MixerLayer 中进行计算啦。

2. MixerLayer

MixerLayer的结构如下图所示:
复现 Mlp-Mixer(pytorch and keras)
我们看一下论文里给出的公示:
复现 Mlp-Mixer(pytorch and keras)
MLP 是两个全连接层的感知机,W1,W2,对应token_mixer中两个全连接的权重,W3,W4则表示channel_mixer两个全连接的权重。σ表示GELU激活函数。那么公示就很简单了,输入X经过LN,再乘以W1,再经过激活函数后乘以W2,再加上X。第二个公式也是相同的计算过程。
大佬们都喜欢吧简单的问题复杂化吗?
将前面通过编码得到的矩阵经过Layer Norm 在将矩阵进行旋转(T 表示旋转)连接MLP1,MLP1 就是文章token_mixer 用来寻找像素与像素之间的关系,其中,MLP1中的权值共享。计算完之后,再将矩阵旋转回来,通过Layer Norm 后再接一个channel_mixer 用于寻找通道与通道之间的关系。其中MixerLayer 还启用了ResNet中的跨连结构,跨连结构的作用可以参考[优雅的复现ResNet50],看到这里,是不是感觉它跟卷积的原理很类似。
从上图可以看出MixerLayer的输入维度和输出维度相同,并且通过MLP的方式来寻找图片像素与像素,通道与通道的关系。
这就是MLP-MIXER的网络结构了,目前的了解,没有开源的pytorch或者TensorFlow 预训练的权重。官方给出的代码和权重是基于JAX的。
谁找到了权重 给我说一下 嘤嘤嘤~~~

3.MLP-Mixer模型类型

复现 Mlp-Mixer(pytorch and keras)

文章中给出的模型参数列表,Patch resolution 就是patch 的长宽。Hidden size 就是映射成前面提到的大矩阵的维度,Squence length 是计算后的结果。
以上图红色部分为例,输入图像大小为224*224,
然后分成的块大小为32*32.
那么 (224*224)\(32*32)=7*7,Squence length 就为 49。Dc和Ds分别表示token_mixer和channel_mixer 中全连接层节点的个数。
这就是MLP-Mixer的全部过程了。

二、网络实现

1.pytorch复现

实现的难点在于,矩阵旋转,我们使用einops中的Rearrange实现矩阵旋转。还需要使用torchsummary 来查看网络结构。安装:

pip install einops
pip install torchsummary

首先我们来实现MLP 也就是FeedForward:

#定义多层感知机
class FeedForward(nn.Module):
    def __init__(self,dim,hidden_dim,dropout=0.):
        super().__init__()
        self.net=nn.Sequential(
            #由此可以看出 FeedForward 的输入和输出维度是一致的
            nn.Linear(dim,hidden_dim),
            #激活函数
            nn.GELU(),
            #防止过拟合
            nn.Dropout(dropout),
            #重复上述过程
            nn.Linear(hidden_dim,dim),

            nn.Dropout(dropout)
        )
    def forward(self,x):
        x=self.net(x)
        return x
#测试多层感知机
# mlp=FeedForward(10,20,0.4).to(device)
# summary(mlp,input_size=(10,))

实现过程很简单,就是全连接结构
接着我们来实现Mixer Block,里面包含了 token_mixer 和channel_mixer,还有矩阵转置。

#使用Rearrange 实现旋转
Rearrange('b n d -> b d n') #这里是[batch_size, num_patch, dim] -> [batch_size, dim, num_patch]

实现如下:

class MixerBlock(nn.Module):
    def __init__(self,dim,num_patch,token_dim,channel_dim,dropout=0.):
        super().__init__()
        self.token_mixer=nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n d -> b d n'),
            FeedForward(num_patch,token_dim,dropout),
            Rearrange('b d n -> b n d')

         )
        self.channel_mixer=nn.Sequential(
            nn.LayerNorm(dim),
            FeedForward(dim,channel_dim,dropout)
        )
    def forward(self,x):

        x=x+self.token_mixer(x)

        x=x+self.channel_mixer(x)

        return x

#测试mixerblock
# x=torch.randn(1,196,512)
# mixer_block=MixerBlock(512,196,32,32)
# x=mixer_block(x)
# print(x.shape)

更具上述定义好的网络零件,就可以实现我们最终的主网络mlp-mixer:

class MLPMixer(nn.Module):
    def __init__(self,in_channels,dim,num_classes,patch_size,image_size,depth,token_dim,channel_dim,dropout=0.):
        super().__init__()
        assert image_size%patch_size==0
        self.num_patches=(image_size//patch_size)**2
        #embedding 操作,看见没用卷积来分成一小块一小块的
        self.to_embedding=nn.Sequential(         Conv2d(in_channels=in_channels,out_channels=dim,kernel_size=patch_size,stride=patch_size),
            Rearrange('b c h w -> b (h w) c')
        )
        self.mixer_blocks=nn.ModuleList([])
        for _ in range(depth):
            self.mixer_blocks.append(MixerBlock(dim,self.num_patches,token_dim,channel_dim,dropout))
        self.layer_normal=nn.LayerNorm(dim)

        self.mlp_head=nn.Sequential(
            nn.Linear(dim,num_classes)
        )
    def forward(self,x):
        x=self.to_embedding(x)
        for mixer_block in self.mixer_blocks:
            x=mixer_block(x)
        x=self.layer_normal(x)
        x=x.mean(dim=1)

        x=self.mlp_head(x)

        return x
#测试Mlp-Mixer
if __name__ == '__main__':    
      model = MLPMixer(in_channels=3, dim=512, num_classes=1000, patch_size=16, image_size=224, depth=1, token_dim=256,
                     channel_dim=2048).to(device)
      summary(model,(3,224,224))

depth=1,网络深度等于1 这样方便我们看整体结构,最后的运行结果:
复现 Mlp-Mixer(pytorch and keras)

2.keras 复现

keras的复现过程与pytorch类似,但有几个注意的地方,网络中使用的GELU 激活函数在tensorflow>=2.4才可以使用,conda 国内镜像很难安装tensorflow 2.4 以上的GPU 版本。如果要使用的话tensorflow2.4一下版本的话,自定义GELU激活函数如下:

def gelu(x):
    cdf = 0.5 * (1.0 + tf.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf

在keras中使用类的方法定义自己的网络层时,需要重写 get_config 函数 不然模型无法保存。keras中借助Permute 层实现转置。
keras的完整实现如下:

from tensorflow import keras
import tensorflow as tf
import numpy as np
from keras import backend as K
from tensorflow.keras.layers import (
    Add,
    Dense,
    Conv2D,
    GlobalAveragePooling1D,
    Flatten,
    Layer,
    LayerNormalization,
    Permute,
    Softmax,
    Activation,
)
class MlpBlock(Layer):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        activation=None,
        **kwargs
    ):
        super(MlpBlock, self).__init__(**kwargs)

        if activation is None:
            activation = keras.activations.gelu

        self.dim = dim
        self.hidden_dim = dim
        self.dense1 = Dense(hidden_dim)
        self.activation = Activation(activation)
        self.dense2 = Dense(dim)

    def call(self, inputs):
        x = inputs
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        return x

    def compute_output_shape(self, input_signature):
        return (input_signature[0], self.dim)

    def get_config(self):
        config = super(MlpBlock, self).get_config()
        config.update({
            'dim': self.dim,
            'hidden_dim': self.hidden_dim
        })
        return config


class MixerBlock(Layer):
    def __init__(
        self,
        num_patches: int,
        channel_dim: int,
        token_mixer_hidden_dim: int,
        channel_mixer_hidden_dim: int = None,
        activation=None,
        **kwargs
    ):
        super(MixerBlock, self).__init__(**kwargs)

        self.num_patches = num_patches
        self.channel_dim = channel_dim
        self.token_mixer_hidden_dim = token_mixer_hidden_dim
        self.channel_mixer_hidden_dim = channel_mixer_hidden_dim
        self.activation = activation

        if activation is None:
            self.activation = keras.activations.gelu

        if channel_mixer_hidden_dim is None:
            channel_mixer_hidden_dim = token_mixer_hidden_dim

        self.norm1 = LayerNormalization(axis=1)
        self.permute1 = Permute((2, 1))
        self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, name='token_mixer')

        self.permute2 = Permute((2, 1))
        self.norm2 = LayerNormalization(axis=1)
        self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, name='channel_mixer')

        self.skip_connection1 = Add()
        self.skip_connection2 = Add()

    def call(self, inputs):
        x = inputs
        skip_x = x
        x = self.norm1(x)
        x = self.permute1(x)
        x = self.token_mixer(x)

        x = self.permute2(x)

        x = self.skip_connection1([x, skip_x])
        skip_x = x

        x = self.norm2(x)
        x = self.channel_mixer(x)

        x = self.skip_connection2([x, skip_x])  # TODO need 2?

        return x

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super(MixerBlock, self).get_config()
        config.update({
            'num_patches': self.num_patches,
            'channel_dim': self.channel_dim,
            'token_mixer_hidden_dim': self.token_mixer_hidden_dim,
            'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,
            'activation': self.activation,
        })
        return config


def MlpMixerModel(
        input_shape: int,
        num_classes: int,
        num_blocks: int,
        patch_size: int,
        hidden_dim: int,
        tokens_mlp_dim: int,
        channels_mlp_dim: int = None,
        use_softmax: bool = False,
):
    height, width, _ = input_shape

    if channels_mlp_dim is None:
        channels_mlp_dim = tokens_mlp_dim

    num_patches = (height*width)//(patch_size**2)  # TODO verify how this behaves with same padding

    inputs = keras.Input(input_shape)
    x = inputs

    x = Conv2D(hidden_dim,
               kernel_size=patch_size,
               strides=patch_size,
               padding='same',
               name='projector')(x)

    x = keras.layers.Reshape([-1, hidden_dim])(x)

    for _ in range(num_blocks):
        x = MixerBlock(num_patches=num_patches,
                       channel_dim=hidden_dim,
                       token_mixer_hidden_dim=tokens_mlp_dim,
                       channel_mixer_hidden_dim=channels_mlp_dim)(x)

    x = Flatten()(x)  # TODO verify this global average pool is correct choice here

    x = LayerNormalization(name='pre_head_layer_norm')(x)
    x = Dense(num_classes, name='head')(x)

    if use_softmax:
        x = Softmax()(x)
    return keras.Model(inputs, x)

总结

代码复现过程,参考了论文地址给出的GitHub的链接,自己手撕代码的能力还是比较弱,不足之处就是没有使用数据集去训练它并达到一个不错的效果。后续如果有结果,会更新这篇博客,在训练模型时,调参过程是真的累,应该还是缺少理论知识的原因。后续会学习如何进行调参。
创作不易,点赞鼓励。
最后说一句,pytorch真香~~~~~~,前面的博客LeNetAlexNet已添加pytorch实现。

上一篇:数组学习系列1-VBA数组详解(2)


下一篇:Access vba实例