参考:Attention-UNet for Pneumothorax Segmentation
以上为 Attention Gate 的原始结构图,可以按照下面的结构图进行理解:
-
输入为 $x$(最上方的 conv2d_126,即编码器的跳跃连接特征)和 $g$(左边的 up_sampling2d_11,即解码器上采样后的门控信号)
-
$x$ 和 $g$ 各自经过一个 1×1 卷积(输出通道数均为 inter_channel),然后将两者逐元素相加
-
之后依次经过 ReLU、输出单通道的 1×1 卷积和 Sigmoid,得到取值在 0~1 之间的注意力权重图,如下图的 activation_19
-
最后将 activation_19 与 $x$ 逐元素相乘(单通道的权重图会在通道维上广播到 $x$ 的所有通道),就完成了整个过程
实现代码:
from keras import Input
from keras.layers import Conv2D, Activation, UpSampling2D, Lambda, Dropout, MaxPooling2D, multiply, add
from keras import backend as K
from keras.models import Model
from keras.utils import plot_model

IMG_CHANNEL = 3


def _channel_axis(data_format):
    """Return the tensor axis that holds channels for the given ``data_format``."""
    return 1 if data_format == 'channels_first' else -1


def AttnBlock2D(x, g, inter_channel, data_format='channels_first'):
    """Additive attention gate as used in Attention U-Net.

    ``x`` and ``g`` are each projected to ``inter_channel`` channels by a
    1x1 convolution. The projections are summed and passed through ReLU,
    then a 1x1 convolution reduces the result to a single channel. A
    sigmoid turns that channel into an attention map, and the map rescales
    ``x``.

    Args:
        x: feature map to be gated (the encoder skip connection).
        g: gating feature map; must have the same spatial size as ``x``.
        inter_channel: number of channels of the intermediate projections.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        ``x`` multiplied by the attention map; same shape as ``x``.
    """
    theta_x = Conv2D(inter_channel, (1, 1), strides=(1, 1), data_format=data_format)(x)
    phi_g = Conv2D(inter_channel, (1, 1), strides=(1, 1), data_format=data_format)(g)
    f = Activation('relu')(add([theta_x, phi_g]))
    psi_f = Conv2D(1, (1, 1), strides=(1, 1), data_format=data_format)(f)
    rate = Activation('sigmoid')(psi_f)
    # ``rate`` has a single channel; the multiply broadcasts it over every channel of x.
    att_x = multiply([x, rate])
    return att_x


def attention_up_and_concate(down_layer, layer, data_format='channels_first'):
    """Upsample ``down_layer`` by 2, use it to gate ``layer``, then concatenate.

    Args:
        down_layer: coarser decoder feature map; after 2x upsampling it is
            the gating signal.
        layer: encoder skip-connection feature map at twice the spatial
            size of ``down_layer``.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        The upsampled ``down_layer`` and the gated ``layer``, concatenated
        along the channel axis (in that order).
    """
    channel_axis = _channel_axis(data_format)
    in_channel = K.int_shape(down_layer)[channel_axis]
    up = UpSampling2D(size=(2, 2), data_format=data_format)(down_layer)
    layer = AttnBlock2D(x=layer, g=up, inter_channel=in_channel // 4, data_format=data_format)
    # The Lambda receives the two-element list [up, layer]; join it on the channel axis
    # so both data formats are handled by the same code path.
    my_concat = Lambda(lambda tensors: K.concatenate([tensors[0], tensors[1]], axis=channel_axis))
    concate = my_concat([up, layer])
    return concate


def _conv_block(x, filters, dropout_rate, data_format):
    """Apply 3x3 conv (ReLU) -> Dropout -> 3x3 conv (ReLU) with 'same' padding."""
    x = Conv2D(filters, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
    x = Dropout(dropout_rate)(x)
    x = Conv2D(filters, (3, 3), activation='relu', padding='same', data_format=data_format)(x)
    return x


# Attention U-Net
def att_unet(img_w, img_h, n_label, data_format='channels_first',
             depth=4, features=32, dropout_rate=0.2):
    """Build an Attention U-Net with a per-pixel sigmoid output.

    Args:
        img_w: input image width in pixels.
        img_h: input image height in pixels.
        n_label: number of output channels (one sigmoid map each).
        data_format: 'channels_first' or 'channels_last'; used by every layer.
        depth: number of encoder down-sampling stages. ``img_w`` and
            ``img_h`` should be divisible by ``2 ** depth``; otherwise the
            pooled and upsampled feature maps do not line up at the
            concatenations.
        features: number of filters in the first encoder stage. It doubles
            at each deeper stage.
        dropout_rate: dropout rate between the two 3x3 convolutions of each stage.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    # Keras 2D layers expect (channels, rows, cols) or (rows, cols, channels);
    # rows correspond to the image height.
    if data_format == 'channels_first':
        input_shape = (IMG_CHANNEL, img_h, img_w)  # e.g. (3, 160, 160)
    else:
        input_shape = (img_h, img_w, IMG_CHANNEL)
    inputs = Input(input_shape)
    x = inputs
    skips = []

    # ENCODER: keep each stage's output as a skip connection, then downsample.
    for _ in range(depth):
        x = _conv_block(x, features, dropout_rate, data_format)
        skips.append(x)
        x = MaxPooling2D((2, 2), data_format=data_format)(x)
        features = features * 2

    # BOTTLENECK
    x = _conv_block(x, features, dropout_rate, data_format)

    # DECODER: upsample, gate the matching skip connection, fuse.
    for skip in reversed(skips):
        features = features // 2
        x = attention_up_and_concate(x, skip, data_format=data_format)
        x = _conv_block(x, features, dropout_rate, data_format)

    conv6 = Conv2D(n_label, (1, 1), padding='same', data_format=data_format)(x)
    conv7 = Activation('sigmoid')(conv6)
    model = Model(inputs=inputs, outputs=conv7)
    return model


IMG_WIDTH = 160
IMG_HEIGHT = 160
model = att_unet(IMG_WIDTH, IMG_HEIGHT, n_label=1)
model.summary()

plot_model(model, to_file='Att_U_Net.png', show_shapes=True)
输出:
Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_11 (InputLayer) [(None, 3, 160, 160) 0 __________________________________________________________________________________________________ conv2d_119 (Conv2D) (None, 32, 160, 160) 896 input_11[0][0] __________________________________________________________________________________________________ dropout_45 (Dropout) (None, 32, 160, 160) 0 conv2d_119[0][0] __________________________________________________________________________________________________ conv2d_120 (Conv2D) (None, 32, 160, 160) 9248 dropout_45[0][0] __________________________________________________________________________________________________ max_pooling2d_32 (MaxPooling2D) (None, 32, 80, 80) 0 conv2d_120[0][0] __________________________________________________________________________________________________ conv2d_121 (Conv2D) (None, 64, 80, 80) 18496 max_pooling2d_32[0][0] __________________________________________________________________________________________________ dropout_46 (Dropout) (None, 64, 80, 80) 0 conv2d_121[0][0] __________________________________________________________________________________________________ conv2d_122 (Conv2D) (None, 64, 80, 80) 36928 dropout_46[0][0] __________________________________________________________________________________________________ max_pooling2d_33 (MaxPooling2D) (None, 64, 40, 40) 0 conv2d_122[0][0] __________________________________________________________________________________________________ conv2d_123 (Conv2D) (None, 128, 40, 40) 73856 max_pooling2d_33[0][0] __________________________________________________________________________________________________ dropout_47 (Dropout) (None, 128, 40, 40) 0 conv2d_123[0][0] 
__________________________________________________________________________________________________ conv2d_124 (Conv2D) (None, 128, 40, 40) 147584 dropout_47[0][0] __________________________________________________________________________________________________ max_pooling2d_34 (MaxPooling2D) (None, 128, 20, 20) 0 conv2d_124[0][0] __________________________________________________________________________________________________ conv2d_125 (Conv2D) (None, 256, 20, 20) 295168 max_pooling2d_34[0][0] __________________________________________________________________________________________________ dropout_48 (Dropout) (None, 256, 20, 20) 0 conv2d_125[0][0] __________________________________________________________________________________________________ conv2d_126 (Conv2D) (None, 256, 20, 20) 590080 dropout_48[0][0] __________________________________________________________________________________________________ max_pooling2d_35 (MaxPooling2D) (None, 256, 10, 10) 0 conv2d_126[0][0] __________________________________________________________________________________________________ conv2d_127 (Conv2D) (None, 512, 10, 10) 1180160 max_pooling2d_35[0][0] __________________________________________________________________________________________________ dropout_49 (Dropout) (None, 512, 10, 10) 0 conv2d_127[0][0] __________________________________________________________________________________________________ conv2d_128 (Conv2D) (None, 512, 10, 10) 2359808 dropout_49[0][0] __________________________________________________________________________________________________ up_sampling2d_11 (UpSampling2D) (None, 512, 20, 20) 0 conv2d_128[0][0] __________________________________________________________________________________________________ conv2d_129 (Conv2D) (None, 128, 20, 20) 32896 conv2d_126[0][0] __________________________________________________________________________________________________ conv2d_130 (Conv2D) (None, 128, 20, 20) 65664 up_sampling2d_11[0][0] 
__________________________________________________________________________________________________ add_6 (Add) (None, 128, 20, 20) 0 conv2d_129[0][0] conv2d_130[0][0] __________________________________________________________________________________________________ activation_18 (Activation) (None, 128, 20, 20) 0 add_6[0][0] __________________________________________________________________________________________________ conv2d_131 (Conv2D) (None, 1, 20, 20) 129 activation_18[0][0] __________________________________________________________________________________________________ activation_19 (Activation) (None, 1, 20, 20) 0 conv2d_131[0][0] __________________________________________________________________________________________________ multiply_6 (Multiply) (None, 256, 20, 20) 0 conv2d_126[0][0] activation_19[0][0] __________________________________________________________________________________________________ lambda_5 (Lambda) (None, 768, 20, 20) 0 up_sampling2d_11[0][0] multiply_6[0][0] __________________________________________________________________________________________________ conv2d_132 (Conv2D) (None, 256, 20, 20) 1769728 lambda_5[0][0] __________________________________________________________________________________________________ dropout_50 (Dropout) (None, 256, 20, 20) 0 conv2d_132[0][0] __________________________________________________________________________________________________ conv2d_133 (Conv2D) (None, 256, 20, 20) 590080 dropout_50[0][0] __________________________________________________________________________________________________ up_sampling2d_12 (UpSampling2D) (None, 256, 40, 40) 0 conv2d_133[0][0] __________________________________________________________________________________________________ conv2d_134 (Conv2D) (None, 64, 40, 40) 8256 conv2d_124[0][0] __________________________________________________________________________________________________ conv2d_135 (Conv2D) (None, 64, 40, 40) 16448 up_sampling2d_12[0][0] 
__________________________________________________________________________________________________ add_7 (Add) (None, 64, 40, 40) 0 conv2d_134[0][0] conv2d_135[0][0] __________________________________________________________________________________________________ activation_20 (Activation) (None, 64, 40, 40) 0 add_7[0][0] __________________________________________________________________________________________________ conv2d_136 (Conv2D) (None, 1, 40, 40) 65 activation_20[0][0] __________________________________________________________________________________________________ activation_21 (Activation) (None, 1, 40, 40) 0 conv2d_136[0][0] __________________________________________________________________________________________________ multiply_7 (Multiply) (None, 128, 40, 40) 0 conv2d_124[0][0] activation_21[0][0] __________________________________________________________________________________________________ lambda_6 (Lambda) (None, 384, 40, 40) 0 up_sampling2d_12[0][0] multiply_7[0][0] __________________________________________________________________________________________________ conv2d_137 (Conv2D) (None, 128, 40, 40) 442496 lambda_6[0][0] __________________________________________________________________________________________________ dropout_51 (Dropout) (None, 128, 40, 40) 0 conv2d_137[0][0] __________________________________________________________________________________________________ conv2d_138 (Conv2D) (None, 128, 40, 40) 147584 dropout_51[0][0] __________________________________________________________________________________________________ up_sampling2d_13 (UpSampling2D) (None, 128, 80, 80) 0 conv2d_138[0][0] __________________________________________________________________________________________________ conv2d_139 (Conv2D) (None, 32, 80, 80) 2080 conv2d_122[0][0] __________________________________________________________________________________________________ conv2d_140 (Conv2D) (None, 32, 80, 80) 4128 up_sampling2d_13[0][0] 
__________________________________________________________________________________________________ add_8 (Add) (None, 32, 80, 80) 0 conv2d_139[0][0] conv2d_140[0][0] __________________________________________________________________________________________________ activation_22 (Activation) (None, 32, 80, 80) 0 add_8[0][0] __________________________________________________________________________________________________ conv2d_141 (Conv2D) (None, 1, 80, 80) 33 activation_22[0][0] __________________________________________________________________________________________________ activation_23 (Activation) (None, 1, 80, 80) 0 conv2d_141[0][0] __________________________________________________________________________________________________ multiply_8 (Multiply) (None, 64, 80, 80) 0 conv2d_122[0][0] activation_23[0][0] __________________________________________________________________________________________________ lambda_7 (Lambda) (None, 192, 80, 80) 0 up_sampling2d_13[0][0] multiply_8[0][0] __________________________________________________________________________________________________ conv2d_142 (Conv2D) (None, 64, 80, 80) 110656 lambda_7[0][0] __________________________________________________________________________________________________ dropout_52 (Dropout) (None, 64, 80, 80) 0 conv2d_142[0][0] __________________________________________________________________________________________________ conv2d_143 (Conv2D) (None, 64, 80, 80) 36928 dropout_52[0][0] __________________________________________________________________________________________________ up_sampling2d_14 (UpSampling2D) (None, 64, 160, 160) 0 conv2d_143[0][0] __________________________________________________________________________________________________ conv2d_144 (Conv2D) (None, 16, 160, 160) 528 conv2d_120[0][0] __________________________________________________________________________________________________ conv2d_145 (Conv2D) (None, 16, 160, 160) 1040 up_sampling2d_14[0][0] 
__________________________________________________________________________________________________ add_9 (Add) (None, 16, 160, 160) 0 conv2d_144[0][0] conv2d_145[0][0] __________________________________________________________________________________________________ activation_24 (Activation) (None, 16, 160, 160) 0 add_9[0][0] __________________________________________________________________________________________________ conv2d_146 (Conv2D) (None, 1, 160, 160) 17 activation_24[0][0] __________________________________________________________________________________________________ activation_25 (Activation) (None, 1, 160, 160) 0 conv2d_146[0][0] __________________________________________________________________________________________________ multiply_9 (Multiply) (None, 32, 160, 160) 0 conv2d_120[0][0] activation_25[0][0] __________________________________________________________________________________________________ lambda_8 (Lambda) (None, 96, 160, 160) 0 up_sampling2d_14[0][0] multiply_9[0][0] __________________________________________________________________________________________________ conv2d_147 (Conv2D) (None, 32, 160, 160) 27680 lambda_8[0][0] __________________________________________________________________________________________________ dropout_53 (Dropout) (None, 32, 160, 160) 0 conv2d_147[0][0] __________________________________________________________________________________________________ conv2d_148 (Conv2D) (None, 32, 160, 160) 9248 dropout_53[0][0] __________________________________________________________________________________________________ conv2d_149 (Conv2D) (None, 1, 160, 160) 33 conv2d_148[0][0] __________________________________________________________________________________________________ activation_26 (Activation) (None, 1, 160, 160) 0 conv2d_149[0][0] ================================================================================================== Total params: 7,977,941 Trainable params: 7,977,941 Non-trainable params: 0 
__________________________________________________________________________________________________
结构图如下: