# Code:
from keras import Model, layers
from keras.layers import Input, Conv2D, BatchNormalization, Activation, Reshape, MaxPooling2D, UpSampling2D


def _conv_bn(x, filters):
    """Apply a 3x3 'same'-padded ReLU Conv2D (He-normal init), then BatchNormalization."""
    # NOTE(review): BatchNormalization is applied *after* the ReLU, matching the
    # original layer stack; reference SegNet normalizes before the activation —
    # confirm which ordering is intended before changing it.
    conv = Conv2D(filters, 3, activation='relu', padding='same',
                  kernel_initializer='he_normal')(x)
    return BatchNormalization()(conv)


def segnet(pretrained_weights=None, input_size=(512, 512, 3), classNum=2, learning_rate=1e-5):
    """Build an (uncompiled) SegNet-style encoder/decoder segmentation model.

    The encoder has five stages of 3x3 Conv2D+BatchNormalization blocks, each
    followed by 2x2 max-pooling; the decoder mirrors it with five 2x
    upsampling stages. Because there are five pool/upsample pairs, the output
    only matches the input's spatial size when height and width are
    divisible by 32.

    Args:
        pretrained_weights: Optional path to weights passed to
            ``model.load_weights``; nothing is loaded when falsy.
        input_size: Shape passed to ``keras.layers.Input`` (default
            ``(512, 512, 3)``).
        classNum: Number of segmentation classes. For ``classNum <= 2`` the
            head is a single sigmoid channel (binary segmentation, same as
            before); for ``classNum > 2`` it is ``classNum`` softmax channels.
        learning_rate: Currently unused — the model is returned uncompiled,
            so the caller chooses optimizer/loss. Kept for interface
            compatibility.

    Returns:
        A ``keras.Model`` mapping ``inputs`` to the per-pixel predictions.
    """
    # (filters, number of conv+BN layers) for each stage.
    encoder_stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
    decoder_stages = ((512, 3), (512, 3), (256, 3), (128, 2), (64, 2))

    inputs = Input(input_size)
    x = inputs

    # Encoder: conv+BN layers, then halve the spatial size.
    for filters, n_convs in encoder_stages:
        for _ in range(n_convs):
            x = _conv_bn(x, filters)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Decoder: double the spatial size, then conv+BN layers.
    for filters, n_convs in decoder_stages:
        x = UpSampling2D(size=(2, 2))(x)
        for _ in range(n_convs):
            x = _conv_bn(x, filters)

    # Output head: 1x1 conv producing per-pixel class scores.
    if classNum > 2:
        # Multi-class: one channel per class, softmax across channels.
        outputs = Conv2D(classNum, 1, padding='same', activation='softmax')(x)
    else:
        # Binary: a single sigmoid probability map (original behavior).
        outputs = Conv2D(1, 1, padding='same', activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs)

    # Previously this argument was accepted but silently ignored.
    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model