T2N

自建网络

class TEModule(nn.Module):
    """Two-branch block: a per-frame "temporal" branch plus a multi-scale spatial branch.

    Input and output are ``(N * T, C, H, W)`` tensors, where ``T == n_segment``
    frames of each clip are stacked along the batch axis (the layout implied by
    ``nt // n_segment`` in ``forward``).

    Args:
        channel: number of input/output channels. Branch 2 splits the channels
            into ``[C/2 | C/4 | C/4]``, so ``channel`` must be divisible by 4.
        reduction: stored on the module but not used by this block.
        n_segment: number of frames per clip (T).

    Raises:
        ValueError: if ``channel`` is not a positive multiple of 4.
    """

    expansion = 4

    def __init__(self, channel, reduction=16, n_segment=8):
        super(TEModule, self).__init__()
        if channel <= 0 or channel % 4 != 0:
            raise ValueError(
                f"channel must be a positive multiple of 4, got {channel}"
            )
        self.channel = channel
        self.reduction = reduction
        self.n_segment = n_segment

        # Backbone: 1x1 conv + BN, shared by both branches.
        self.conv1 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=self.channel)

        # Branch 1: conv1d along T, then 3x3 conv2d over (H, W).
        # NOTE(review): with kernel_size=1 this conv1d only mixes channels
        # within each frame; it has no receptive field across frames. A temporal
        # conv would usually use kernel_size=3, padding=1. Confirm the intent.
        self.conv_b1_1 = nn.Conv1d(
            self.channel, self.channel,
            kernel_size=1, bias=False,
        )
        self.conv_b1_2 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=3, padding=1,
            bias=False,
        )

        # Branch 2: 1x1 conv, then 3x3 convs on channel splits [C/2 | C/4 | C/4].
        self.conv_b2_0 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=1,
            bias=False,
        )
        self.conv_b2_1 = nn.Conv2d(
            in_channels=self.channel // 2,
            out_channels=self.channel // 2,
            kernel_size=3, padding=1,
            bias=False,
        )
        # (The STM network obtains inter-frame differences by subtracting
        # adjacent frames; this block does not do that.)
        self.conv_b2_2 = nn.Conv2d(
            in_channels=self.channel // 4,
            out_channels=self.channel // 4,
            kernel_size=3, padding=1,
            bias=False,
        )
        self.conv_b2_3 = nn.Conv2d(
            in_channels=self.channel // 4,
            out_channels=self.channel // 4,
            kernel_size=3, padding=1,
            bias=False,
        )

        # Output projection applied to the sum of both branches.
        self.conv_b2_4 = nn.Conv2d(
            in_channels=self.channel,
            out_channels=self.channel,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Apply the block to an ``(N * T, C, H, W)`` tensor; returns the same shape.

        Raises:
            ValueError: if the leading dimension is not divisible by ``n_segment``.
        """
        nt, c, h, w = x.size()
        if nt % self.n_segment != 0:
            raise ValueError(
                f"batch dimension {nt} is not divisible by n_segment={self.n_segment}"
            )
        n_batch = nt // self.n_segment

        # Shared backbone. Conv2d/BatchNorm2d need the 4-D (N*T, C, H, W) layout.
        x1 = self.bn1(self.conv1(x))

        # Branch 1: fold H and W into the batch so Conv1d sees (batch, C, T).
        b1 = x1.view(n_batch, self.n_segment, c, h, w)
        b1 = b1.permute(0, 3, 4, 2, 1)  # (N, H, W, C, T)
        # reshape, not view: the permuted tensor is non-contiguous.
        b1 = b1.reshape(n_batch * h * w, c, self.n_segment)
        b1 = self.conv_b1_1(b1)
        b1 = b1.view(n_batch, h, w, c, self.n_segment)
        b1 = b1.permute(0, 4, 3, 1, 2).reshape(nt, c, h, w)  # back to (N*T, C, H, W)
        b1 = self.conv_b1_2(b1)

        # Branch 2: 1x1 conv, then split the channel axis into [C/2 | C/4 | C/4].
        b2 = self.conv_b2_0(x1)
        half = self.channel // 2
        three_quarters = half + self.channel // 4
        b2_01 = self.conv_b2_1(b2[:, :half])
        b2_02_1 = self.conv_b2_2(b2[:, half:three_quarters])
        # The last quarter's conv output is summed with the previous quarter's output.
        b2_03 = b2_02_1 + self.conv_b2_3(b2[:, three_quarters:])
        b2 = torch.cat((b2_01, b2_02_1, b2_03), dim=1)

        return self.conv_b2_4(b1 + b2)

上一篇:Improvise a Jazz Solo with an LSTM Network


下一篇:清北学堂2020.11.26笔记