文章目录
- 1. 1x1卷积核-> 3x3卷积核
- 2. 输入x --> 3x3卷积核,无变化
- 3. 代码
1. 1x1卷积核-> 3x3卷积核
假设我们有一个1x1的卷积核,希望通过在四周补零把它变为一个等价的3x3卷积核。由于新增的8个权重全为0,卷积滑动时只有中心位置参与计算,相邻像素之间不会产生关联,因此(在3x3卷积使用 padding=1,即 "same" 的前提下)结果与原来的1x1卷积完全相同:
$$
\begin{equation}
\begin{bmatrix}4\end{bmatrix}\to
\begin{bmatrix}
0&0&0\\
0&4&0\\
0&0&0
\end{bmatrix}
\end{equation}
$$
2. 输入x --> 3x3卷积核,无变化
我们希望输入x经过一个3x3卷积之后保持不变(即恒等映射)。这需要同时满足两点:一是空间上不融合——卷积核在滑动过程中只取中心像素,不混合相邻像素,也不改变像素值;二是通道间不融合——每个输出通道只取对应的那个输入通道。假设我们有一个卷积定义如下:
$$
\begin{equation}
\mathrm{conv2d}(2,\ 2,\ 3,\ \mathrm{padding}=\text{"same"})
\end{equation}
$$
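可得:输入通道为2,输出通道为2(对应 PyTorch 中 `nn.Conv2d(in_channels, out_channels, kernel_size)` 的参数顺序),卷积核大小为3。padding="same" 表示输出特征图的空间尺寸与输入相同,对3x3卷积来说就是在四周各补1个像素。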
卷积权重的形状为(2,2,3,3),即(输出通道数, 输入通道数, 卷积核高, 卷积核宽)。
- 可以把(2,2,3,3)简单拆分成两部分来理解:后两维(3,3)是每个卷积核本身,负责空间上的滑动计算;前两维(2,2)描述如何把2个输入通道组合成2个输出通道,可以看作一个2x2的矩阵,其中第i行第j列的元素就是"输入通道j → 输出通道i"所用的3x3卷积核。那么,怎样才能实现通道分离呢?只要让这个2x2矩阵成为对角形式即可:对角线上放只有中心为1的卷积核a,非对角线上放全0的卷积核b,如下所示:
$$
\begin{equation}
\begin{bmatrix} a&b\\ b&a \end{bmatrix};\quad
a\to \begin{bmatrix} 0&0&0\\ 0&1&0\\ 0&0&0 \end{bmatrix};\quad
b\to \begin{bmatrix} 0&0&0\\ 0&0&0\\ 0&0&0 \end{bmatrix}
\end{equation}
$$

- 这样就同时实现了通道上的分离(非对角的卷积核b全为0)和像素上的分离(对角的卷积核a只在中心为1),卷积输出与输入完全相同。
3. 代码
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :confusion_conv2d.py
# @Time :2024/12/9 8:23
# @Author :Jason Zhang
import torch
from torch import nn
import torch.nn.functional as F
def identity_kernel(channels, kernel_size=3):
    """Build a conv weight that turns a stride-1, "same"-padded conv into the identity map.

    The returned tensor has shape ``(channels, channels, kernel_size, kernel_size)``.
    Only the centre tap of each diagonal kernel ``[i, i]`` is 1. Each output channel
    therefore copies its own input channel (no channel mixing) and picks up no
    neighbouring pixels (no spatial mixing).

    Args:
        channels: number of input channels; must equal the number of output channels.
        kernel_size: spatial kernel size. It must be odd so the kernel has a centre tap.

    Returns:
        A float32 tensor of shape ``(channels, channels, kernel_size, kernel_size)``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
    weight = torch.zeros(channels, channels, kernel_size, kernel_size)
    centre = kernel_size // 2
    diag = torch.arange(channels)
    # Only the diagonal (out == in) kernels get a 1, and only at the centre tap.
    weight[diag, diag, centre, centre] = 1.0
    return weight


def pad_kernel(weight, kernel_size=3):
    """Zero-pad a conv weight spatially into a centred ``kernel_size`` x ``kernel_size`` kernel.

    This expresses e.g. a 1x1 conv as an equivalent kxk conv. The added taps are 0,
    so a "same"-padded kxk conv with the padded weight reproduces the original output.

    Args:
        weight: tensor of shape ``(out_channels, in_channels, kh, kw)``.
        kernel_size: target spatial size.

    Returns:
        A tensor of shape ``(out_channels, in_channels, kernel_size, kernel_size)``.

    Raises:
        ValueError: if ``weight`` is larger than ``kernel_size``, or cannot be centred
            (``kernel_size - kh`` or ``kernel_size - kw`` is odd).
    """
    kh, kw = weight.shape[-2:]
    pad_h, pad_w = kernel_size - kh, kernel_size - kw
    if pad_h < 0 or pad_w < 0 or pad_h % 2 or pad_w % 2:
        raise ValueError(f"cannot centre a {kh}x{kw} kernel inside {kernel_size}x{kernel_size}")
    # A 4-tuple pads only the last two dims: (left, right, top, bottom).
    return F.pad(weight, (pad_w // 2, pad_w // 2, pad_h // 2, pad_h // 2))


def fuse_branches(conv_kxk, conv_1x1, add_identity=True):
    """Merge ``conv_kxk(x) + conv_1x1(x) [+ x]`` into a single kxk ``nn.Conv2d``.

    Convolution is linear in its weights, so branches with the same kernel size can be
    summed kernel-wise. The 1x1 weight is zero-padded to kxk, the identity branch is
    replaced by ``identity_kernel``, and the biases are added (a missing bias counts as 0).

    Assumes both convs have stride 1, no dilation, groups=1, a square kernel, and that
    ``conv_kxk`` preserves the spatial size (e.g. ``padding="same"``), as in ``main``.

    Raises:
        ValueError: if ``add_identity`` is True but the in/out channel counts differ.
    """
    kernel_size = conv_kxk.kernel_size[0]
    in_ch, out_ch = conv_kxk.in_channels, conv_kxk.out_channels
    if add_identity and in_ch != out_ch:
        raise ValueError("identity branch requires in_channels == out_channels")
    with torch.no_grad():
        weight = conv_kxk.weight + pad_kernel(conv_1x1.weight, kernel_size)
        bias = torch.zeros(out_ch)
        for conv in (conv_kxk, conv_1x1):
            if conv.bias is not None:
                bias = bias + conv.bias
        if add_identity:
            weight = weight + identity_kernel(in_ch, kernel_size)
    fused = nn.Conv2d(in_ch, out_ch, kernel_size, padding=conv_kxk.padding)
    fused.weight = nn.Parameter(weight)
    fused.bias = nn.Parameter(bias)
    return fused


def main():
    """Demonstrate expressing a 1x1 conv and the identity as 3x3 convs, then fusing them."""
    in_channels = 2
    out_channels = 2
    kernel_size = 3
    h = 9
    w = 9
    # PyTorch images are NCHW: (batch, channels, height, width).
    x = torch.ones(1, in_channels, h, w)

    # 1. Three-branch block: kxk conv + 1x1 conv + identity.
    #    conv_2d.weight has shape (out_channels, in_channels, k, k): one kxk kernel per
    #    (output channel, input channel) pair, i.e. 2x2 = 4 kernels of size 3x3 here.
    conv_2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                        kernel_size=kernel_size, padding="same")
    print(f"conv_2d.weight.shape={conv_2d.weight.shape}")
    print(f"conv_2d.weight={conv_2d.weight}")
    conv_2d_pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
    result1 = conv_2d(x) + conv_2d_pointwise(x) + x
    print(f"result1.shape=\n{result1.shape}")
    print(f"result1=\n{result1}")

    # 2. The identity branch written as a kxk conv: centre tap 1 on diagonal kernels only.
    identity_weight = identity_kernel(in_channels, kernel_size)
    identity_bias = torch.zeros(out_channels)
    print(f"identity=\n{identity_weight}")
    print(f"identity.shape=\n{identity_weight.shape}")
    test_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                            kernel_size=kernel_size, padding="same")
    print(test_conv2d.weight.shape)
    print(test_conv2d.bias.shape)
    test_conv2d.weight = nn.Parameter(identity_weight)
    test_conv2d.bias = nn.Parameter(identity_bias)
    input_x = torch.randint(1, 10, (1, in_channels, h, w), dtype=torch.float)
    out_y = test_conv2d(input_x)
    print(f"input_x=\n{input_x}")
    print(f"out_y=\n{out_y}")
    check_out = torch.allclose(input_x, out_y)
    print(f"input_x is {check_out} same for out_y")

    # 3. A 1x1 conv written as a zero-padded kxk conv.
    point_wise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding="same")
    with torch.no_grad():
        point_wise_weight = pad_kernel(point_wise.weight, kernel_size)
    print(f"point_wise=\n{point_wise_weight}")
    print(f"point_wise.shape=\n{point_wise_weight.shape}")
    point_wise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, padding="same")
    point_wise_conv.weight = nn.Parameter(point_wise_weight)
    # Copy rather than alias the 1x1 conv's bias Parameter.
    point_wise_conv.bias = nn.Parameter(point_wise.bias.detach().clone())
    point_wise_out = point_wise(input_x)
    print(f"point_wise_out=\n{point_wise_out}")
    point_3_wise_out = point_wise_conv(input_x)
    check_3 = torch.allclose(point_wise_out, point_3_wise_out)
    print(f"check_3 is {check_3} same for point_3_wise_out")

    # 4. Fuse all three branches of step 1 into a single kxk conv and compare.
    fused = fuse_branches(conv_2d, conv_2d_pointwise)
    result2 = fused(x)
    # Different summation order -> tiny float differences; allow a small atol.
    check_fused = torch.allclose(result1, result2, atol=1e-6)
    print(f"check_fused is {check_fused} same for result1")


if __name__ == "__main__":
    main()