16、PyTorch中进行卷积残差模块算子融合

文章目录

  • 1. 1x1卷积核-> 3x3卷积核
  • 2. 输入x --> 3x3卷积核,无变化
  • 3. 代码

1. 1x1卷积核-> 3x3卷积核

假设我们有一个1x1的卷积核,在它四周补零即可变为一个3x3的卷积核。由于只有中心位置非零,卷积时相邻像素之间不会发生混合,因此两者的计算结果完全相同:
$$
\begin{bmatrix}4\end{bmatrix}\to
\begin{bmatrix} 0&0&0\\ 0&4&0\\ 0&0&0 \end{bmatrix}
$$

2. 输入x --> 3x3卷积核,无变化

我们希望找到一个3x3卷积,使输入x经过它之后保持不变,即用卷积来表示恒等映射。这需要满足两个条件:一是空间上不融合,卷积核在滑动过程中只取中心像素,不与相邻像素混合;二是通道上不融合,每个输出通道只取对应的输入通道。假设我们有如下卷积定义:
$$
\mathrm{conv2d}(2,\ 2,\ 3,\ \mathrm{padding}=\text{"same"})
$$
可得:输入通道为2,输出通道为2,卷积核大小为3;padding="same"表示输出特征图的空间尺寸与输入相同。
卷积权重的形状为(2,2,3,3),依次对应(输出通道, 输入通道, 卷积核高, 卷积核宽)。

  • 可以把(2,2,3,3)拆分成两部分来看:后面的(3,3)是每个卷积核的空间矩阵,负责滑动窗口运算;前面的(2,2)表示把2个输入通道映射到2个输出通道。那么这个2x2的矩阵怎样才能实现通道分离呢?答案是使用对角矩阵:对角线上的元素a是中心为1、其余为0的3x3卷积核,非对角线上的元素b是全零的3x3卷积核,如下所示:
    $$
    \begin{bmatrix} a&b\\ b&a \end{bmatrix};\quad
    a\to\begin{bmatrix} 0&0&0\\ 0&1&0\\ 0&0&0 \end{bmatrix};\quad
    b\to\begin{bmatrix} 0&0&0\\ 0&0&0\\ 0&0&0 \end{bmatrix}
    $$
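  • 这样就同时实现了通道上的分离和像素上的分离,该卷积的输出与输入完全相同。结合第1节可知,残差块 conv3x3(x) + conv1x1(x) + x 的三个分支都能写成3x3卷积。由于卷积对权重和偏置是线性的,把三组权重、三组偏置分别相加,就得到一个等价的3x3卷积,这就是算子融合。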

3. 代码

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :confusion_conv2d.py
# @Time      :2024/12/9 8:23
# @Author    :Jason Zhang
import torch
from torch import nn
import torch.nn.functional as F

def _check_odd_kernel(kernel_size):
    """Raise ValueError unless ``kernel_size`` is a positive odd integer.

    An odd size is required so the kernel has a single centre tap that lines up
    with the output pixel under symmetric "same" padding.
    """
    if not isinstance(kernel_size, int) or kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size!r}")


def identity_kernel(channels, kernel_size=3, *, dtype=None, device=None):
    """Build a conv weight that turns a stride-1 "same" convolution into the identity.

    Output channel ``c`` reads only input channel ``c`` (diagonal over the two
    channel axes), and only the centre tap of that kernel is 1. So there is no
    mixing between channels and no mixing between neighbouring pixels.

    Args:
        channels: number of input channels, which equals the number of output channels.
        kernel_size: odd spatial size of the square kernel.
        dtype, device: forwarded to ``torch.zeros`` (``None`` = PyTorch defaults).

    Returns:
        Tensor of shape ``(channels, channels, kernel_size, kernel_size)``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    _check_odd_kernel(kernel_size)
    weight = torch.zeros(channels, channels, kernel_size, kernel_size, dtype=dtype, device=device)
    centre = kernel_size // 2
    diag = torch.arange(channels, device=device)
    weight[diag, diag, centre, centre] = 1
    return weight


def pad_pointwise_kernel(weight, kernel_size=3):
    """Zero-pad a 1x1 conv weight to ``kernel_size`` x ``kernel_size``.

    The original 1x1 tap ends up at the centre and every other tap is 0. A
    stride-1 "same" convolution with the padded weight therefore gives the same
    output as the original 1x1 convolution.

    Args:
        weight: 1x1 conv weight of shape ``(out_channels, in_channels, 1, 1)``.
        kernel_size: odd target spatial size.

    Returns:
        Tensor of shape ``(out_channels, in_channels, kernel_size, kernel_size)``.

    Raises:
        ValueError: if ``kernel_size`` is not odd or ``weight`` is not 1x1.
    """
    _check_odd_kernel(kernel_size)
    if weight.shape[-2:] != (1, 1):
        raise ValueError(f"expected a 1x1 kernel, got spatial size {tuple(weight.shape[-2:])}")
    pad = kernel_size // 2
    # F.pad pads the last dims first: (left, right, top, bottom).
    return F.pad(weight, (pad, pad, pad, pad))


def fuse_residual_conv(conv, pointwise):
    """Fold ``conv(x) + pointwise(x) + x`` into a single equivalent ``nn.Conv2d``.

    Convolution is linear in its weight and bias. Once the 1x1 branch and the
    identity shortcut are rewritten as kxk kernels, the three branches can be
    fused by adding their weights and biases.

    Args:
        conv: square, odd-sized, stride-1, groups=1 conv with "same"-equivalent padding.
            Its ``in_channels`` must equal its ``out_channels``.
        pointwise: 1x1, stride-1, groups=1 conv with the same channel counts as ``conv``.

    Returns:
        A new kxk ``nn.Conv2d`` with ``padding="same"`` whose output equals the
        residual sum.

    Raises:
        ValueError: if the two layers cannot be fused as described above.
    """
    if conv.in_channels != conv.out_channels:
        raise ValueError("the identity shortcut requires in_channels == out_channels")
    if (pointwise.in_channels, pointwise.out_channels) != (conv.in_channels, conv.out_channels):
        raise ValueError("conv and pointwise must have the same in/out channels")
    if pointwise.kernel_size != (1, 1):
        raise ValueError("pointwise must be a 1x1 convolution")
    k_h, k_w = conv.kernel_size
    if k_h != k_w:
        raise ValueError("only square kernels are supported")
    kernel_size = k_h
    _check_odd_kernel(kernel_size)
    for layer in (conv, pointwise):
        if layer.stride != (1, 1) or layer.dilation != (1, 1) or layer.groups != 1:
            raise ValueError("only stride=1, dilation=1, groups=1 convolutions can be fused")
    # Every branch must keep the spatial size, otherwise the sum is undefined.
    if conv.padding not in ("same", (kernel_size // 2, kernel_size // 2)):
        raise ValueError("conv must use 'same' padding")
    if pointwise.padding not in ("same", (0, 0)):
        raise ValueError("pointwise must not pad its input")

    factory = {"dtype": conv.weight.dtype, "device": conv.weight.device}
    # Weight arithmetic only; keep it out of the autograd graph.
    with torch.no_grad():
        weight = (
            conv.weight
            + pad_pointwise_kernel(pointwise.weight, kernel_size)
            + identity_kernel(conv.out_channels, kernel_size, **factory)
        )
        # The identity branch has no bias; layers built with bias=False contribute 0.
        bias = torch.zeros(conv.out_channels, **factory)
        for layer in (conv, pointwise):
            if layer.bias is not None:
                bias = bias + layer.bias

    fused = nn.Conv2d(conv.in_channels, conv.out_channels, kernel_size, padding="same",
                      padding_mode=conv.padding_mode, **factory)
    fused.weight = nn.Parameter(weight)
    fused.bias = nn.Parameter(bias)
    return fused


if __name__ == "__main__":
    in_channels = 2
    out_channels = 2
    kernel_size = 3
    height = 9
    width = 9
    # PyTorch convolutions take NCHW input: (batch, channels, height, width).
    x = torch.ones(1, in_channels, height, width)

    # 1. Reference residual block: kxk conv + 1x1 conv + identity shortcut.
    #    The kxk weight has shape (out_channels, in_channels, k, k), i.e. one
    #    kxk kernel per (output channel, input channel) pair: 2x2 = 4 kernels here.
    conv_2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding="same")
    print(f"conv_2d.weight.shape={conv_2d.weight.shape}")
    print(f"conv_2d.weight={conv_2d.weight}")
    conv_2d_pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
    result1 = conv_2d(x) + conv_2d_pointwise(x) + x
    print(f"result1.shape=\n{result1.shape}")
    print(f"result1=\n{result1}")

    # 2. The identity shortcut written as a kxk convolution.
    identity_weight = identity_kernel(in_channels, kernel_size)
    identity_bias = torch.zeros(out_channels)
    print(f"identity=\n{identity_weight}")
    print(f"identity.shape=\n{identity_weight.shape}")
    test_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding="same")
    print(test_conv2d.weight.shape)
    print(test_conv2d.bias.shape)
    test_conv2d.weight = nn.Parameter(identity_weight)
    test_conv2d.bias = nn.Parameter(identity_bias)
    input_x = torch.randint(1, 10, (1, in_channels, height, width), dtype=torch.float)
    out_y = test_conv2d(input_x)
    print(f"input_x=\n{input_x}")
    print(f"out_y=\n{out_y}")
    check_out = torch.allclose(input_x, out_y)
    print(f"input_x is {check_out} same for out_y")

    # 3. A 1x1 conv zero-padded to kxk gives the same output as the 1x1 conv.
    point_wise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding="same")
    point_wise_weight = pad_pointwise_kernel(point_wise.weight, kernel_size)
    print(f"point_wise=\n{point_wise_weight}")
    print(f"point_wise.shape=\n{point_wise_weight.shape}")
    point_wise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                padding="same")
    point_wise_conv.weight = nn.Parameter(point_wise_weight)
    point_wise_conv.bias = nn.Parameter(point_wise.bias)
    point_wise_out = point_wise(input_x)
    print(f"point_wise_out=\n{point_wise_out}")
    point_3_wise_out = point_wise_conv(input_x)
    check_3 = torch.allclose(point_wise_out, point_3_wise_out)
    print(f"check_3 is {check_3} same for point_3_wise_out")

    # 4. Fuse the three branches of step 1 into one kxk conv and confirm it
    #    reproduces result1 (up to float summation-order rounding).
    fused_conv = fuse_residual_conv(conv_2d, conv_2d_pointwise)
    result2 = fused_conv(x)
    print(f"result2=\n{result2}")
    check_fused = torch.allclose(result1, result2, atol=1e-5)
    print(f"check_fused is {check_fused} same for result1")
上一篇:数据可视化大屏UI组件库:B端科技感素材PSD