Original paper
U-Net: Convolutional Networks for Biomedical Image Segmentation
Network architecture
Notes
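- The paper uses unpadded convolutions, so the feature map shrinks after every convolution. Our first guess was that padding was not yet common at the time, but the paper gives its own reason: it keeps only the valid part of each convolution so that large images can be segmented tile by tile (the overlap-tile strategy). In this reimplementation we use padding so that convolutions preserve the spatial size. The output then has the same height and width as the input, provided both are divisible by 16 (the network pools four times). For example, a 64 x 64 input produces a 64 x 64 output.
- The paper does not use Batch Normalization. Our guess is that the medical image segmentation datasets of the time were small enough not to need it. We add it in this reimplementation.
- The skip connections described in the paper are presumably a concatenation along the channel dimension, so we implement them with torch.cat().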
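The effect of padding described in the first note can be checked with the standard convolution output-size formula. The following is a minimal sketch in plain Python, not part of the network code: `conv_out_size` is a hypothetical helper introduced here for illustration, and 572 is used for the unpadded case because that is the input tile size shown in the paper's architecture figure.

```python
def conv_out_size(size, kernel_size=3, stride=1, padding=0):
    """Spatial size after one convolution (standard floor formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Unpadded 3x3 convolutions, as in the paper: each one trims 2 pixels.
unpadded = conv_out_size(conv_out_size(572, padding=0), padding=0)

# Padded 3x3 convolutions (padding=1), as in this reimplementation:
# the spatial size is preserved.
padded = conv_out_size(conv_out_size(64, padding=1), padding=1)

print(unpadded)  # 568
print(padded)    # 64
```

With a 3 x 3 kernel and stride 1, padding of 1 is exactly what cancels the 2-pixel loss, which is why the convolutions below are created with `nn.Conv2d(..., 3, 1, 1)`.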
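Implementation approach
First, wrap the pair of consecutive blue arrows (Conv + ReLU), which appears repeatedly in the architecture diagram, into a reusable module.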
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv -> BatchNorm -> ReLU) blocks."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            # kernel 3, stride 1, padding 1: the spatial size is preserved
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)
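Next, split the whole network into three parts: left, center, and right.
The left part consists of 4 double convolutions, each followed by a downsampling step (pooling). The center is a single double convolution. The right part consists of 4 upsampling steps (transposed convolution), each followed by a double convolution, and ends with a 1 x 1 convolution that produces the output.
Complete code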
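To make the left / center / right split concrete, the channel counts and spatial sizes at each stage can be traced by hand. The sketch below does this in plain Python without running the network. `trace_unet_shapes` and the stage names `cat_1` to `cat_4` are hypothetical, introduced only for illustration; the channel widths are the ones used in this reimplementation, and the input is assumed to be square with a side length divisible by 16.

```python
def trace_unet_shapes(size, in_channels=3):
    """Return (stage, channels, size) tuples for a square input of side `size`.

    Assumes `size` is divisible by 16 so that every skip connection lines up.
    """
    widths = [64, 128, 256, 512]
    shapes = [("input", in_channels, size)]
    skips = []

    # Left: a padded double conv keeps the size; 2x2 max pooling halves it.
    for i, channels in enumerate(widths, start=1):
        shapes.append((f"left_conv_{i}", channels, size))
        skips.append(channels)
        size //= 2
        shapes.append((f"down_{i}", channels, size))

    # Center: one double conv at the lowest resolution.
    channels = 1024
    shapes.append(("center_conv", channels, size))

    # Right: the transposed conv doubles the size and halves the channels,
    # concatenating the skip doubles the channels again, and the double
    # conv brings them back down.
    for i, skip_channels in enumerate(reversed(skips), start=1):
        channels //= 2
        size *= 2
        shapes.append((f"up_{i}", channels, size))
        shapes.append((f"cat_{i}", channels + skip_channels, size))
        shapes.append((f"right_conv_{i}", channels, size))

    return shapes


for stage, channels, side in trace_unet_shapes(64):
    print(f"{stage:14s} channels={channels:5d} size={side} x {side}")
```

The `cat_*` rows show why each right-side double convolution takes twice as many input channels as the preceding transposed convolution outputs: for example, `up_1` yields 512 channels, the concatenation yields 1024, and the double convolution reduces that to 512 again.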
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv -> BatchNorm -> ReLU) blocks."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            # kernel 3, stride 1, padding 1: the spatial size is preserved
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # left: double conv, then 2x2 max pooling halves the spatial size
        self.left_conv_1 = DoubleConv(3, 64)
        self.down_1 = nn.MaxPool2d(2, 2)
        self.left_conv_2 = DoubleConv(64, 128)
        self.down_2 = nn.MaxPool2d(2, 2)
        self.left_conv_3 = DoubleConv(128, 256)
        self.down_3 = nn.MaxPool2d(2, 2)
        self.left_conv_4 = DoubleConv(256, 512)
        self.down_4 = nn.MaxPool2d(2, 2)
        # center
        self.center_conv = DoubleConv(512, 1024)
        # right: the transposed conv doubles the spatial size and halves the
        # channels; the double conv takes the concatenated (doubled) channels
        self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.right_conv_1 = DoubleConv(1024, 512)
        self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.right_conv_2 = DoubleConv(512, 256)
        self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.right_conv_3 = DoubleConv(256, 128)
        self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.right_conv_4 = DoubleConv(128, 64)
        # output: 1x1 conv. It has 3 output channels here so that the output
        # shape matches the input; for segmentation, use the number of classes.
        self.output = nn.Conv2d(64, 3, 1, 1, 0)

    def forward(self, x):
        # left
        x1 = self.left_conv_1(x)
        x1_down = self.down_1(x1)
        x2 = self.left_conv_2(x1_down)
        x2_down = self.down_2(x2)
        x3 = self.left_conv_3(x2_down)
        x3_down = self.down_3(x3)
        x4 = self.left_conv_4(x3_down)
        x4_down = self.down_4(x4)
        # center
        x5 = self.center_conv(x4_down)
        # right: upsample, concatenate the matching left feature map along
        # the channel dimension (skip connection), then double conv
        x6_up = self.up_1(x5)
        temp = torch.cat((x6_up, x4), dim=1)
        x6 = self.right_conv_1(temp)
        x7_up = self.up_2(x6)
        temp = torch.cat((x7_up, x3), dim=1)
        x7 = self.right_conv_2(temp)
        x8_up = self.up_3(x7)
        temp = torch.cat((x8_up, x2), dim=1)
        x8 = self.right_conv_3(temp)
        x9_up = self.up_4(x8)
        temp = torch.cat((x9_up, x1), dim=1)
        x9 = self.right_conv_4(temp)
        # output
        output = self.output(x9)
        return output
Quick test
If the implementation is correct, the network's output should have the same shape as its input.
if __name__ == "__main__":
    a = torch.rand(10, 3, 32, 32)  # batch of 10 three-channel 32 x 32 images
    model = UNet()
    b = model(a)
    print(b.size())  # torch.Size([10, 3, 32, 32])
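The test uses a 32 x 32 input, which is divisible by 16. That matters: `nn.MaxPool2d(2, 2)` rounds odd sizes down, while `nn.ConvTranspose2d` with kernel size 2 and stride 2 exactly doubles the size, so for other input sizes the upsampled feature map may not match the skip connection and `torch.cat` would fail. The plain-Python sketch below mirrors that size arithmetic; `check_input_size` is a hypothetical helper for illustration, not part of the model.

```python
def check_input_size(size, num_poolings=4):
    """Return True if pooling and then upsampling restores every skip size."""
    sizes = [size]
    for _ in range(num_poolings):
        sizes.append(sizes[-1] // 2)  # 2x2 max pooling floors odd sizes

    current = sizes[-1]
    for skip in reversed(sizes[:-1]):
        current *= 2  # kernel-2, stride-2 transposed conv doubles exactly
        if current != skip:
            return False  # concatenation with the skip would fail here
    return True


print(check_input_size(32))   # True
print(check_input_size(48))   # True
print(check_input_size(100))  # False: 100 -> 50 -> 25 -> 12, but 12 * 2 = 24
```

For sizes that fail this check, one common workaround (not used in this post) is to pad or resize the input to the next multiple of 16 before passing it to the network.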
References
https://zhuanlan.zhihu.com/p/87593567
https://github.com/milesial/Pytorch-UNet