# Custom-built network (自建网络)
class TEModule(nn.Module):
    """Two-branch block applied to a stack of video frames.

    The input is a 4-D tensor whose leading dimension packs ``n_batch``
    clips of ``n_segment`` frames each, i.e. ``(n_batch * n_segment, C, H, W)``.
    After a shared 1x1 conv + batch norm, two branches are computed and summed:

    * Branch 1: a ``Conv1d`` along the segment (time) axis followed by a
      3x3 ``Conv2d`` over the spatial axes.
    * Branch 2: a 1x1 ``Conv2d``, after which the channels are split into
      the first half, the third quarter and the last quarter; each group
      goes through its own 3x3 ``Conv2d`` and the results are re-assembled
      to ``C`` channels.

    The sum is passed through a final 1x1 ``Conv2d``. The output has the
    same shape as the input.

    Args:
        channel: Number of input/output channels. Must be divisible by 4
            so that the branch-2 channel groups add up to ``channel``.
        reduction: Stored on the module; not used by any layer here.
        n_segment: Number of frames per clip packed in the leading dimension.

    Raises:
        ValueError: If ``channel`` is not divisible by 4.
    """

    expansion = 4

    def __init__(self, channel, reduction=16, n_segment=8):
        super().__init__()
        if channel % 4 != 0:
            raise ValueError(
                "channel must be divisible by 4, got {}".format(channel)
            )
        self.channel = channel
        self.reduction = reduction
        self.n_segment = n_segment

        # Backbone: 1x1 conv2d -> batch norm.
        self.conv1 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=self.channel)

        # Branch 1: conv1d over T, then conv2d over HW.
        # NOTE(review): kernel_size=1 means this conv1d does not mix
        # neighbouring frames; kept as in the original definition.
        self.conv_b1_1 = nn.Conv1d(
            self.channel, self.channel,
            kernel_size=1, bias=False,
        )
        self.conv_b1_2 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=3, padding=1,
            bias=False,
        )

        # Branch 2: 1x1 conv2d, then per-channel-group 3x3 conv2d layers.
        self.conv_b2_0 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=1,
            bias=False,
        )
        self.conv_b2_1 = nn.Conv2d(
            in_channels=self.channel // 2,
            out_channels=self.channel // 2,
            kernel_size=3, padding=1,
            bias=False,
        )
        # The STM network obtains inter-frame differences by subtracting
        # consecutive frames.
        self.conv_b2_2 = nn.Conv2d(
            in_channels=self.channel // 4,
            out_channels=self.channel // 4,
            kernel_size=3, padding=1,
            bias=False,
        )
        self.conv_b2_3 = nn.Conv2d(
            in_channels=self.channel // 4,
            out_channels=self.channel // 4,
            kernel_size=3, padding=1,
            bias=False,
        )
        # Final 1x1 conv2d applied to the sum of both branches.
        self.conv_b2_4 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(n_batch * n_segment, C, H, W)``.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If the leading dimension of ``x`` is not a multiple
                of ``n_segment``.
        """
        nt, c, h, w = x.size()
        t = self.n_segment
        if nt % t != 0:
            raise ValueError(
                "leading dimension {} is not a multiple of n_segment {}".format(nt, t)
            )
        n_batch = nt // t

        # Backbone operates on the packed 4-D layout expected by Conv2d/BatchNorm2d.
        x1 = self.bn1(self.conv1(x))

        # Branch 1: fold (N, H, W) into the batch axis so Conv1d sees (N*H*W, C, T).
        b1 = x1.view(n_batch, t, c, h, w).permute(0, 3, 4, 2, 1)
        b1 = self.conv_b1_1(b1.reshape(n_batch * h * w, c, t))
        # Restore the packed (N*T, C, H, W) layout for the spatial conv.
        b1 = b1.view(n_batch, h, w, c, t).permute(0, 4, 3, 1, 2)
        b1 = self.conv_b1_2(b1.reshape(nt, c, h, w))

        # Branch 2: split along the channel axis (dim=1) into 1/2, 1/4, 1/4.
        half = c // 2
        quarter = c // 4
        b2_in = self.conv_b2_0(x1)
        b2_first = self.conv_b2_1(b2_in[:, :half])
        b2_mid = self.conv_b2_2(b2_in[:, half:half + quarter])
        b2_last = self.conv_b2_3(b2_in[:, half + quarter:])
        # The last group is fused with the middle group before concatenation.
        b2_sum = torch.add(b2_mid, b2_last)
        b2 = torch.cat([b2_first, b2_mid, b2_sum], dim=1)

        merged = torch.add(b1, b2)
        return self.conv_b2_4(merged)