[PyTorch] U-Net Architecture Analysis and Code Implementation

Original paper

U-Net: Convolutional Networks for Biomedical Image Segmentation

Network architecture

[Figure: U-Net architecture diagram]

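Notes

  1. The convolutions in the paper use no padding, so the feature maps shrink after every convolution. Presumably padding was not yet common practice at the time. This reimplementation uses padding so that each convolution preserves the spatial size, and the network's output has the same height and width as its input. For example, a 64 x 64 input yields a 64 x 64 output. Because the network downsamples by a factor of 2 four times, this holds when the input height and width are divisible by 16.
  2. The paper does not use Batch Normalization. Presumably the medical image segmentation datasets of the time were small and it was not considered necessary. This reimplementation adds it.
  3. The skip connections described in the paper are presumably implemented by concatenating along the channel dimension, which is done here with torch.cat().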
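Notes 1 and 3 can be checked directly on tensor shapes. The following is a minimal sketch rather than anything taken from the paper; the tensor sizes, channel counts, and variable names are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 3, 64, 64)  # illustrative input: batch 1, 3 channels, 64 x 64

# 3x3 convolution without padding, as in the paper: height and width each lose 2 pixels
conv_no_pad = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=0)
shape_no_pad = tuple(conv_no_pad(x).shape)

# 3x3 convolution with padding=1, as in this reimplementation: size is preserved
conv_pad = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
shape_pad = tuple(conv_pad(x).shape)

# skip connection: concatenate two feature maps along the channel dimension (dim=1)
a = torch.rand(1, 64, 32, 32)
b = torch.rand(1, 64, 32, 32)
shape_cat = tuple(torch.cat((a, b), dim=1).shape)

print(shape_no_pad)  # (1, 8, 62, 62)
print(shape_pad)     # (1, 8, 64, 64)
print(shape_cat)     # (1, 128, 32, 32)
```

Concatenation adds the channel counts (64 + 64 = 128) and requires the two tensors to agree in every other dimension, which is why preserving the spatial size with padding makes the skip connections straightforward.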

Implementation approach

First, wrap the pair of blue arrows (Conv + ReLU) that recurs throughout the architecture diagram into a single module.

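import torch.nn as nn


class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv -> BatchNorm -> ReLU) blocks."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            # kernel_size=3, stride=1, padding=1 keeps height and width unchanged
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # second convolution keeps the channel count at out_channels
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)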

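Then divide the whole network into three parts, left, center, and right, as follows:
The left side consists of 4 double convolutions and 4 downsampling steps (max pooling); the center is a single double convolution; the right side consists of 4 upsampling steps (transposed convolution) and 4 double convolutions, followed by a final 1 x 1 convolution that produces the output.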
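To make the left/right pairing concrete, here is a rough sketch of a single level: one pooling step on the way down and one transposed convolution plus concatenation on the way up. The tensor `deeper` is a random stand-in for the output of the next level's double convolution, and all sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn

down = nn.MaxPool2d(2, 2)               # halves height and width
up = nn.ConvTranspose2d(128, 64, 2, 2)  # doubles height and width, 128 -> 64 channels

skip = torch.rand(1, 64, 32, 32)        # encoder feature map kept for the skip connection
pooled = down(skip)                     # input to the next, deeper level

deeper = torch.rand(1, 128, 16, 16)     # stand-in for the deeper level's double-conv output
upsampled = up(deeper)                  # back to the spatial size of `skip`
merged = torch.cat((upsampled, skip), dim=1)  # channels: 64 + 64 = 128

print(tuple(pooled.shape))     # (1, 64, 16, 16)
print(tuple(upsampled.shape))  # (1, 64, 32, 32)
print(tuple(merged.shape))     # (1, 128, 32, 32)
```

The concatenation is why each double convolution on the right side takes twice as many input channels as the transposed convolution before it outputs.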

Full code

import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # left (encoder): double conv, then 2x2 max pooling halves height and width
        self.left_conv_1 = DoubleConv(3, 64)
        self.down_1 = nn.MaxPool2d(2, 2)

        self.left_conv_2 = DoubleConv(64, 128)
        self.down_2 = nn.MaxPool2d(2, 2)

        self.left_conv_3 = DoubleConv(128, 256)
        self.down_3 = nn.MaxPool2d(2, 2)

        self.left_conv_4 = DoubleConv(256, 512)
        self.down_4 = nn.MaxPool2d(2, 2)

        # center: double conv at the lowest resolution
        self.center_conv = DoubleConv(512, 1024)

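        # right (decoder): transposed conv doubles height and width;
        # each double conv takes 2x channels because of the skip concatenation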
        self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.right_conv_1 = DoubleConv(1024, 512)

        self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.right_conv_2 = DoubleConv(512, 256)

        self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.right_conv_3 = DoubleConv(256, 128)

        self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.right_conv_4 = DoubleConv(128, 64)

        # output: 1x1 convolution maps 64 feature channels to 3 output channels
        self.output = nn.Conv2d(64, 3, 1, 1, 0)

    def forward(self, x):
        # left
        x1 = self.left_conv_1(x)
        x1_down = self.down_1(x1)

        x2 = self.left_conv_2(x1_down)
        x2_down = self.down_2(x2)

        x3 = self.left_conv_3(x2_down)
        x3_down = self.down_3(x3)

        x4 = self.left_conv_4(x3_down)
        x4_down = self.down_4(x4)

        # center
        x5 = self.center_conv(x4_down)

        # right: upsample, concatenate with the matching left-side feature map, double conv
        x6_up = self.up_1(x5)
        temp = torch.cat((x6_up, x4), dim=1)
        x6 = self.right_conv_1(temp)

        x7_up = self.up_2(x6)
        temp = torch.cat((x7_up, x3), dim=1)
        x7 = self.right_conv_2(temp)

        x8_up = self.up_3(x7)
        temp = torch.cat((x8_up, x2), dim=1)
        x8 = self.right_conv_3(temp)

        x9_up = self.up_4(x8)
        temp = torch.cat((x9_up, x1), dim=1)
        x9 = self.right_conv_4(temp)

        # output
        output = self.output(x9)

        return output

Testing

If the implementation is correct, the network's output should have the same shape as its input.

if __name__ == "__main__":
    a = torch.rand(10, 3, 32, 32)
    model = UNet()
    b = model(a)
    print(b.size())  # torch.Size([10, 3, 32, 32])
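The test uses a 32 x 32 input, which is divisible by 16. For other sizes the skip connections may not line up, because max pooling floors odd sizes while the transposed convolution always doubles. The sketch below traces one spatial dimension through the four levels using plain arithmetic; `trace_sizes` and `skips_match` are hypothetical helper names introduced here, not part of the original code.

```python
def trace_sizes(size, depth=4):
    """Trace one spatial dimension through `depth` poolings and `depth` upsamplings."""
    encoder = []
    for _ in range(depth):
        encoder.append(size)  # size of the feature map saved for the skip connection
        size = size // 2      # MaxPool2d(2, 2) floors odd sizes
    decoder = []
    for _ in range(depth):
        size = size * 2       # ConvTranspose2d(..., 2, 2) doubles the size
        decoder.append(size)
    # reverse so that index i refers to the same level on both sides
    return encoder, decoder[::-1]


def skips_match(size, depth=4):
    """True if every upsampled map has the same size as its skip connection."""
    encoder, decoder = trace_sizes(size, depth)
    return encoder == decoder


print(trace_sizes(32))  # ([32, 16, 8, 4], [32, 16, 8, 4])
print(trace_sizes(30))  # ([30, 15, 7, 3], [16, 8, 4, 2])
print(skips_match(48), skips_match(40))  # True False
```

With a mismatched size such as 30, torch.cat() in the forward pass would raise an error at the first level where the sizes differ, so one option is to resize or pad inputs to a multiple of 16 before feeding them to the network.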

References

https://zhuanlan.zhihu.com/p/87593567
https://github.com/milesial/Pytorch-UNet
