1、导入相关功能包
import torch
from torch import nn
from torchsummary import summary
from tensorboardX import SummaryWriter
2、定义Affine模块
Affine 是逐通道的仿射变换 y = α·x + β;初始化 α=1、β=0,因此初始时为恒等映射。
class Affine(nn.Module):
    """Learnable element-wise affine transform ``x * alpha + beta``.

    ``alpha`` starts at one and ``beta`` at zero, so a freshly constructed
    module returns its input unchanged.

    Args:
        channel: Size of the trailing dimension of the tensors this module
            will be applied to.
    """

    def __init__(self, channel):
        super().__init__()
        # Shaped (1, 1, channel) so both parameters broadcast over every
        # leading dimension of the input and act on the last axis only.
        self.alpha = nn.Parameter(torch.ones(1, 1, channel))
        self.beta = nn.Parameter(torch.zeros(1, 1, channel))

    def forward(self, x):
        """Return ``x * alpha + beta`` (broadcast over the leading dims)."""
        return x * self.alpha + self.beta
3、定义PreAffinePostLayerScale模块
参考:Going deeper with Image Transformers
class PreAffinePostLayerScale(nn.Module):  # https://arxiv.org/abs/2103.17239
    """Residual wrapper: ``fn(affine(x)) * scale + x``.

    The input is first passed through an ``Affine`` layer, then through the
    wrapped module ``fn``; the result is multiplied by a learnable scale and
    added back to the input.

    Args:
        dim: Size of the trailing dimension of the input; used for both the
            ``Affine`` layer and the scale parameter.
        depth: Depth value used only to pick the initial scale:
            ``0.1`` for ``depth <= 18``, ``1e-5`` for ``18 < depth <= 24``,
            and ``1e-6`` otherwise.
        fn: The module to wrap. Its output must be broadcast-compatible
            with its input so the residual sum is valid.
    """

    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        # Every entry of the scale starts at init_eps; the shape
        # (1, 1, dim) broadcasts over the leading dimensions of the input.
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.affine = Affine(dim)
        self.fn = fn

    def forward(self, x):
        """Apply the affine -> fn -> scale branch and add the residual."""
        return self.fn(self.affine(x)) * self.scale + x
该模块实现了残差块中的如下计算:$y = x + \mathrm{scale} \odot \mathrm{fn}(\mathrm{Affine}(x))$,其中 scale 是可学习的逐通道缩放系数。对应以下部分:
4、构建resMLP网络
class ResMLP(nn.Module):
    """ResMLP image classifier built from patch embedding and residual MLP blocks.

    Each of the ``depth`` layers consists of two residual sub-blocks, both
    wrapped in ``PreAffinePostLayerScale``:

    1. a cross-patch block: a kernel-size-1 ``Conv1d`` whose channels are
       the patches, so it mixes information across patches;
    2. a cross-channel block: ``Linear -> GELU -> Linear`` on the embedding
       dimension.

    Args:
        dim: Embedding size of each patch.
        image_size: Side length of the input image.
        patch_size: Side length of each patch.
        expansion_factor: Hidden-width multiplier of the cross-channel MLP.
        depth: Number of (cross-patch, cross-channel) layers.
        class_num: Number of output classes.

    Note:
        ``forward`` returns softmax probabilities, not logits. Losses that
        apply their own log-softmax (e.g. ``nn.CrossEntropyLoss``) should
        not be fed this output directly.
    """

    def __init__(self, dim=128, image_size=256, patch_size=16, expansion_factor=4, depth=2, class_num=1000):
        super().__init__()
        # Rearange is defined elsewhere; it is expected to turn an image
        # batch into a sequence of flattened patches -- TODO confirm.
        self.flatten = Rearange(image_size, patch_size)
        num_patches = (image_size // patch_size) ** 2

        # The embedding assumes 3 input channels per pixel.
        self.embedding = nn.Linear((patch_size ** 2) * 3, dim)

        self.mlp = nn.Sequential()
        for i in range(depth):
            # The two sub-blocks need distinct names: add_module replaces an
            # existing entry of the same name, which would silently drop
            # the cross-patch block.
            # The Conv1d mixes across patches, so its channel count is the
            # number of patches (not patch_size ** 2).
            self.mlp.add_module(
                f'fc1_{i}',
                PreAffinePostLayerScale(dim, i + 1, nn.Conv1d(num_patches, num_patches, 1)),
            )
            self.mlp.add_module(
                f'fc2_{i}',
                PreAffinePostLayerScale(dim, i + 1, nn.Sequential(
                    nn.Linear(dim, dim * expansion_factor),
                    nn.GELU(),
                    nn.Linear(dim * expansion_factor, dim),
                )),
            )

        self.aff = Affine(dim)
        self.classifier = nn.Linear(dim, class_num)
        self.softmax = nn.Softmax(1)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch; presumably ``(batch, 3, image_size, image_size)``
                -- depends on what ``Rearange`` accepts, TODO confirm.

        Returns:
            Tensor of class probabilities (softmax over dimension 1).
        """
        y = self.flatten(x)
        y = self.embedding(y)
        y = self.mlp(y)
        y = self.aff(y)
        y = torch.mean(y, dim=1)  # average over the patch axis
        out = self.softmax(self.classifier(y))
        return out
网络结构如下:
5、测试网络
# Smoke test for ResMLP: print a layer summary and export the graph for
# TensorBoard.
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ResMLP(dim=128, image_size=256, patch_size=16, class_num=1000).to(device)
    summary(model, (3, 256, 256))

    # Use randomly initialised values rather than torch.Tensor(...), which
    # returns uninitialised memory that may contain NaN/Inf.
    inputs = torch.randn(2, 3, 256, 256, device=device)
    print(inputs.shape)

    # Save the model as a graph under ./logs for TensorBoard.
    with SummaryWriter(log_dir='logs', comment='model') as w:
        w.add_graph(model, (inputs,))
    print("success")
得到输出如下:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Rearange-1 [-1, 256, 768] 0
Linear-2 [-1, 256, 128] 98,432
Affine-3 [-1, 256, 128] 0
Linear-4 [-1, 256, 512] 66,048
GELU-5 [-1, 256, 512] 0
Linear-6 [-1, 256, 128] 65,664
PreAffinePostLayerScale-7 [-1, 256, 128] 0
Affine-8 [-1, 256, 128] 0
Linear-9 [-1, 1000] 129,000
Softmax-10 [-1, 1000] 0
================================================================
Total params: 359,144
Trainable params: 359,144
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 4.77
Params size (MB): 1.37
Estimated Total Size (MB): 6.89
----------------------------------------------------------------
torch.Size([2, 3, 256, 256])
success
通过TensorboardX查看网络具体结构: