Gated Channel Transformation for Visual Recognition (GCT)
- This article reproduces an attention paper from Baidu & the University of Sydney published at CVPR 2020. The paper proposes a general and lightweight transformation unit that combines a normalization method with an attention (gating) mechanism, making it easy to analyze the relationships between channels (competition or cooperation) while remaining straightforward to train jointly with the network's own parameters.
github:https://github.com/z-x-yang/GCT
Preface
- The authors point out that SENet uses fully connected (FC) layers to process the channel-wise embeddings, which makes it difficult to analyze how channels relate to one another across different layers of the network (for comparison, see the simplified SE-style sketch after this list).
- The authors therefore propose the Gated Channel Transformation (GCT) module, with the following main design choices:
- 1. A normalization module replaces the FC layers to model the relationships between channel features.
- 2. A gating mechanism, applied on top of the normalized embeddings, controls each channel's response.
- 3. Together, the normalization module and the gating mechanism allow GCT to capture both competition and cooperation among channel features.
- The main contribution of the paper is this new channel attention mechanism; the paper reports its performance across a range of CV tasks.
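For comparison, here is a minimal SE-style channel attention sketch in Paddle. It is included only to make concrete the FC-based squeeze-and-excitation design that GCT replaces; the class name SELayer and the reduction ratio of 16 are illustrative choices, not part of the GCT paper or its code.
import paddle
import paddle.nn as nn
class SELayer(nn.Layer):
    # Simplified squeeze-and-excitation: global average pooling followed by two FC layers.
    def __init__(self, num_channels, reduction=16):
        super(SELayer, self).__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)  # squeeze: one value per channel
        self.fc = nn.Sequential(
            nn.Linear(num_channels, num_channels // reduction),
            nn.ReLU(),
            nn.Linear(num_channels // reduction, num_channels),
            nn.Sigmoid())
    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).reshape([b, c]))  # excitation: per-channel weights in (0, 1)
        return x * w.reshape([b, c, 1, 1])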
Related Code
The specific structure of the GCT module, shown in the figure above, consists of three parts:
- 1. Global context embedding (blue part): when the feature map enters the attention module, global average pooling is not used, because global pooling can fail in some situations; instead, an L2 norm is used to compute the global context embedding.
- 2. Channel normalization (green part): an L2 norm is used again, this time across channels; the √C factor keeps the normalized embedding ŝ_c from becoming too small when the number of channels C is large. Compared with the FC layers in SE, this channel normalization is much cheaper to compute.
- 3. Gating adaptation (red part): a gating weight γ and a bias β control whether each channel's features are activated. When a channel's gating weight γ_c is positive, GCT encourages that channel's features to "compete" with the features of other channels; when γ_c is negative, GCT encourages that channel's features to "cooperate" with the others. The formulas below summarize all three steps.
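Putting the three parts together, for an input feature map $x \in \mathbb{R}^{C\times H\times W}$ the GCT transform of channel $c$ is (notation follows the paper; $\epsilon$ is a small stabilizing constant, and the three formulas correspond one-to-one to the implementation in the Verification section below):

$$s_c = \alpha_c \,\lVert x_c \rVert_2 = \alpha_c \Big(\textstyle\sum_{i,j} x_{c,i,j}^2 + \epsilon\Big)^{1/2} \qquad \text{(global context embedding)}$$

$$\hat{s}_c = \frac{\sqrt{C}\, s_c}{\lVert \mathbf{s} \rVert_2} = \frac{\sqrt{C}\, s_c}{\big(\sum_{k=1}^{C} s_k^2 + \epsilon\big)^{1/2}} \qquad \text{(channel normalization)}$$

$$\hat{x}_c = x_c \,\big[1 + \tanh(\gamma_c\, \hat{s}_c + \beta_c)\big] \qquad \text{(gating adaptation)}$$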
Verification
input size = 64,512,16,16 --> GCT --> output size = 64,512,16,16
import sys

import paddle
from paddle import nn
class GCT(nn.Layer):
def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
super(GCT, self).__init__()
        # alpha scales the channel embeddings (init 1); gamma and beta are the gating
        # weight and bias (init 0), so the gate starts at 1 and GCT is initially an identity mapping.
        self.alpha = self.create_parameter(
            shape=[1, num_channels, 1, 1],
            default_initializer=paddle.nn.initializer.Assign(paddle.ones([1, num_channels, 1, 1])))
        self.gamma = self.create_parameter(
            shape=[1, num_channels, 1, 1],
            default_initializer=paddle.nn.initializer.Assign(paddle.zeros([1, num_channels, 1, 1])))
        self.beta = self.create_parameter(
            shape=[1, num_channels, 1, 1],
            default_initializer=paddle.nn.initializer.Assign(paddle.zeros([1, num_channels, 1, 1])))
self.epsilon = epsilon
self.mode = mode
self.after_relu = after_relu
    def forward(self, x):
        if self.mode == 'l2':
            # Global context embedding: per-channel L2 norm over (H, W), scaled by alpha.
            embedding = (x.pow(2).sum((2, 3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
            # Channel normalization: divide by the RMS of the embeddings across channels (gives the sqrt(C) factor).
            norm = self.gamma / (embedding.pow(2).mean(axis=1, keepdim=True) + self.epsilon).pow(0.5)
        elif self.mode == 'l1':
            if not self.after_relu:
                _x = paddle.abs(x)
            else:
                _x = x
            embedding = _x.sum((2, 3), keepdim=True) * self.alpha
            norm = self.gamma / (paddle.abs(embedding).mean(axis=1, keepdim=True) + self.epsilon)
        else:
            print('Unknown mode!')
            sys.exit()
        # Gating adaptation: the gate lies in (0, 2) and equals 1 when gamma = beta = 0.
        gate = 1. + paddle.tanh(embedding * norm + self.beta)
        return x * gate
if __name__=="__main__":
a = paddle.rand([64,512,16,16])
model = GCT(512)
a = model(a)
print(a.shape)
[64, 512, 16, 16]
Validating GCT's Performance
This experiment builds a ResNet18 network to validate the performance. The GCT module is inserted at the end of each residual block, after the last BN layer and before the residual addition, as in the code below.
Building GCT_ResNet18
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
self.conv1 = nn.Conv2D(
inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
self.bn2 = norm_layer(planes)
self.downsample = downsample
        self.stride = stride
        # GCT is inserted after the second BN, before the residual addition
        # (mirroring its placement in BottleneckBlock below).
        self.attention = GCT(planes)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
        out = self.bn2(out)
        out = self.attention(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class BottleneckBlock(nn.Layer):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BottleneckBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2D
width = int(planes * (base_width / 64.)) * groups
self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
self.bn1 = norm_layer(width)
self.conv2 = nn.Conv2D(
width,
width,
3,
padding=dilation,
stride=stride,
groups=groups,
dilation=dilation,
bias_attr=False)
self.bn2 = norm_layer(width)
self.conv3 = nn.Conv2D(
width, planes * self.expansion, 1, bias_attr=False)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
self.attention = GCT(planes * self.expansion)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out = self.attention(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Layer):
def __init__(self,
block,
depth=50,
width=64,
num_classes=1000,
with_pool=True):
super(ResNet, self).__init__()
layer_cfg = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3]
}
layers = layer_cfg[depth]
self.groups = 1
self.base_width = width
self.num_classes = num_classes
self.with_pool = with_pool
self._norm_layer = nn.BatchNorm2D
self.inplanes = 64
self.dilation = 1
self.conv1 = nn.Conv2D(
3,
self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias_attr=False)
self.bn1 = self._norm_layer(self.inplanes)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
if with_pool:
self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
if num_classes > 0:
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
1,
stride=stride,
bias_attr=False),
norm_layer(planes * block.expansion), )
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.with_pool:
x = self.avgpool(x)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
# Note: pretrained=True requires a model_urls dict with download links (not defined here);
# the models in this notebook are all trained from scratch.
def _resnet(arch, Block, depth, pretrained, **kwargs):
model = ResNet(Block, depth, **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
param = paddle.load(weight_path)
model.set_dict(param)
return model
def resnet18(pretrained=False, **kwargs):
return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)
def resnet34(pretrained=False, **kwargs):
return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)
def resnet50(pretrained=False, **kwargs):
return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)
def resnet101(pretrained=False, **kwargs):
return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)
def resnet152(pretrained=False, **kwargs):
return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)
def wide_resnet50_2(pretrained=False, **kwargs):
kwargs['width'] = 64 * 2
return _resnet('wide_resnet50_2', BottleneckBlock, 50, pretrained, **kwargs)
def wide_resnet101_2(pretrained=False, **kwargs):
kwargs['width'] = 64 * 2
return _resnet('wide_resnet101_2', BottleneckBlock, 101, pretrained,
**kwargs)
gct_res18 = resnet18(num_classes=10)
paddle.Model(gct_res18).summary((1, 3, 224, 224))
Cifar10 Data Preparation
import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10
paddle.set_device('gpu')
# Data preparation: convert images to CHW float tensors in [0, 1] first, then normalize with ImageNet statistics
transform = T.Compose([
    T.Resize(size=(224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], data_format='CHW')
])
train_dataset = Cifar10(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
Training ResNet18 on Cifar10
# Model preparation
res18 = paddle.vision.models.resnet18(num_classes=10)
res18.train()
# Training setup
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=res18.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()
res18_loss = []
res18_acc = []
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_loader):
inputs = data[0]
labels = data[1].unsqueeze(1)
predicts = res18(inputs)
loss = loss_fn(predicts, labels)
acc = paddle.metric.accuracy(predicts, labels)
loss.backward()
if batch_id % 100 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
if batch_id % 20 == 0:
            res18_loss.append(loss.numpy())
            res18_acc.append(acc.numpy())
optim.step()
optim.clear_grad()
Training GCT_ResNet18 on Cifar10
# Model preparation
gct_res18 = resnet18(num_classes=10)
gct_res18.train()
# Training setup
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001,parameters=gct_res18.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()
gct_res18_loss = []
gct_res18_acc = []
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_loader):
inputs = data[0]
labels = data[1].unsqueeze(1)
predicts = gct_res18(inputs)
loss = loss_fn(predicts, labels)
acc = paddle.metric.accuracy(predicts, labels)
loss.backward()
if batch_id % 100 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
if batch_id % 20 == 0:
gct_res18_loss.append(loss.numpy())
gct_res18_acc.append(acc.numpy())
optim.step()
optim.clear_grad()
Plotting the Training Curves of ResNet18 and GCT_ResNet18
import matplotlib.pyplot as plt
plt.figure(figsize=(18,12))
plt.subplot(211)
plt.xlabel('iter')
plt.ylabel('loss')
plt.title('train loss')
x=range(len(gct_res18_loss))
plt.plot(x,res18_loss,color='b',label='ResNet18')
plt.plot(x,gct_res18_loss,color='r',label='ResNet18 + GCT')
plt.legend()
plt.grid()
plt.subplot(212)
plt.xlabel('iter')
plt.ylabel('acc')
plt.title('train acc')
x=range(len(gct_res18_acc))
plt.plot(x, res18_acc, color='b',label='ResNet18')
plt.plot(x, gct_res18_acc, color='r',label='ResNet18 + GCT')
plt.legend()
plt.grid()
plt.show()
- Training summary: the curves show that with the GCT attention module the model trains more stably and the recognition accuracy improves somewhat, which indicates that GCT attention is effective.
Summary
- Another interesting part of the paper is an experiment analyzing how γ behaves across the layers of ResNet50 (a sketch for inspecting the learned γ in our GCT_ResNet18 follows this list). It shows that:
- 1. In the shallow layers of the network, γ is small and generally below 0, indicating cooperation between features;
- 2. In the deep layers, γ increases above 0, indicating competition between features, which helps classification.
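A minimal sketch (not from the paper's code) for running a similar check on the GCT_ResNet18 trained above; it assumes the gct_res18 model and the GCT class defined earlier in this notebook:
# Print the mean learned gating weight (gamma) of every GCT module, from shallow to deep layers.
for name, layer in gct_res18.named_sublayers():
    if isinstance(layer, GCT):
        print(f"{name}: mean gamma = {float(layer.gamma.mean()):.4f}")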
Special thanks to 仰世而来丶 (this article draws on https://aistudio.baidu.com/aistudio/projectdetail/1884947?channelType=0&channel=0).