DeepLab-V3+ 语义分割神经网络的 Keras 版本实现

网络结构

DeepLab 系列网络以深度卷积网络为主干（DeepLab v3/v3+ 常用 ResNet 等残差网络，本文实现采用 ResNet50），并在此基础上引入了空洞卷积（Atrous Convolution）。与 DeepLab v3 相比，DeepLab v3+ 为了融合多尺度信息，引入了语义分割中常用的编码器-解码器结构[25][26]：编码器提供丰富的语义信息，解码器恢复精细的物体边缘，从而融合底层特征与高层特征，进一步提升分割边界的准确度。同时，编码器输出特征的分辨率可以通过空洞卷积任意控制，以此在精度和耗时之间取得平衡。

DeepLab v3+ 模型的结构如图 3-6 所示，其编码器的主体是带有空洞卷积的深度卷积神经网络（DCNN）。空洞卷积是 DeepLab 模型的关键之一：卷积核在采样时按膨胀率跳过像素，因此可以在不改变特征图大小、不增加参数量的情况下扩大感受野。这样每个卷积输出包含的信息范围更大，有利于编码器提取更有效的特征和多尺度信息。

同时，DeepLab v3+ 模型采用 ASPP（Atrous Spatial Pyramid Pooling，空洞空间金字塔池化）模块。该模块并行使用一个 1×1 卷积、膨胀率分别为 6、12、18 的 3×3 空洞卷积，以及一个全局平均池化分支（池化结果经上采样恢复到原特征图尺寸），再将各分支拼接，从而在不同感受野下进一步提取多尺度特征。

完整代码

添加引用库

from keras.preprocessing import image
from keras.models import Model, load_model, Sequential
from keras import backend as K
from keras.utils import np_utils
from keras.preprocessing.image import img_to_array
from sklearn.preprocessing import LabelEncoder
from keras import metrics
from keras.losses import binary_crossentropy

import matplotlib.pyplot as plt
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation, Input
from keras.layers import DepthwiseConv2D, ZeroPadding2D, GlobalAveragePooling2D, Lambda, Concatenate, Dropout, Conv2DTranspose
from keras import layers
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger
from keras.layers.merge import concatenate
from PIL import Image
from keras.optimizers import Adam
import tensorflow as tf
from keras.applications.resnet50 import ResNet50

import matplotlib as mpl
import seaborn as sns

import os
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import shutil
import pandas as pd
import time
from tqdm import *

# Expose only GPU 0 to TensorFlow. This must happen before TF first touches
# CUDA; TF initialises devices lazily, so setting it after `import tensorflow`
# normally still works (NOTE(review): confirm no TF op ran earlier).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
seed = 7
# Seed every RNG the script draws from. numpy alone does not cover Keras
# weight initialisers, which draw from TensorFlow's RNG.
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)

数据加载

# Extract the dataset archive once. `!unzip` is IPython shell syntax and is not
# valid in a plain Python script, so use the standard-library zipfile module.
import zipfile

if not os.path.exists('./edge_clip/'):
    with zipfile.ZipFile('/content/drive/MyDrive/edge_clip.zip') as archive:
        archive.extractall('./')

img_w = 512   # network input width in pixels
img_h = 512   # network input height in pixels
n_label = 1   # number of output channels

classes = [0., 1.]  # presumably background / foreground mask values

labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = os.listdir('/content/edge_clip/train/src/')


def load_img(path, grayscale=False):
    """Read an image from disk and scale its pixel values to [0, 1].

    Args:
        path: file path passed to ``cv2.imread``.
        grayscale: if True, read a single-channel 2-D array; otherwise read
            with OpenCV's default colour mode (3 channels, BGR order).

    Returns:
        A float numpy array with values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, which
            would otherwise turn into a silent NaN array.
    """
    flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    img = cv2.imread(path, flag)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {path}")
    return np.asarray(img, dtype="float") / 255.0


filepath = '/content/edge_clip/train/'  # root holding src/ (images) and label/ (masks)


def get_train_val(val_rate=0.2, src_dir=None):
    """Split the training image file names into train and validation lists.

    Args:
        val_rate: fraction in [0, 1] of files assigned to the validation set;
            the count is truncated with ``int()``.
        src_dir: directory to list; defaults to ``filepath + 'src/'``.

    Returns:
        ``(train_set, val_set)`` -- two lists of bare file names.

    Raises:
        ValueError: if ``val_rate`` is outside [0, 1].
    """
    if not 0.0 <= val_rate <= 1.0:
        raise ValueError(f"val_rate must be in [0, 1], got {val_rate}")
    if src_dir is None:
        src_dir = filepath + 'src/'
    # os.listdir order depends on the filesystem. Sort first so the seeded
    # shuffle yields the same split on every machine.
    names = sorted(os.listdir(src_dir))
    # A private RNG keeps the split reproducible without reseeding the global
    # `random` module that other code may rely on.
    random.Random(14).shuffle(names)
    val_num = int(val_rate * len(names))
    return names[val_num:], names[:val_num]


# data for training
def generateData(batch_size, data=None):
    """Yield ``(images, labels)`` batches forever, for ``model.fit``.

    For each file name in ``data``, the image is read from ``filepath + 'src/'``
    and the mask from ``filepath + 'label/'`` via ``load_img`` (the mask as
    grayscale). Both are converted with ``img_to_array`` and stacked into numpy
    batches of ``batch_size``. Names are visited in ``data`` order on every
    pass. A trailing partial batch is dropped, so use
    ``steps_per_epoch = len(data) // batch_size``.

    Args:
        batch_size: samples per yielded batch (must be >= 1).
        data: sequence of file names present in both ``src/`` and ``label/``.

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` < 1 or ``data``
            holds fewer than ``batch_size`` names. Otherwise the loop would
            spin forever without ever yielding.
    """
    data = [] if data is None else list(data)
    if batch_size < 1 or len(data) < batch_size:
        raise ValueError(
            f"need batch_size >= 1 and at least batch_size ({batch_size}) "
            f"file names, got {len(data)}")
    while True:
        train_data, train_label = [], []
        for url in data:
            img = load_img(filepath + 'src/' + url)
            train_data.append(img_to_array(img))
            label = load_img(filepath + 'label/' + url, grayscale=True)
            train_label.append(img_to_array(label))
            if len(train_data) == batch_size:
                yield np.array(train_data), np.array(train_label)
                train_data, train_label = [], []

            # data for validation


def generateValidData(batch_size, data=None):
    """Yield ``(images, labels)`` validation batches forever.

    File loading, batching and end-of-pass handling are exactly those of
    ``generateData``, so this delegates to it instead of duplicating the loop.
    The separate name keeps the validation call site self-describing.

    Args:
        batch_size: samples per yielded batch.
        data: sequence of validation file names (``None`` means empty).

    Returns:
        The generator produced by ``generateData``.
    """
    return generateData(batch_size, [] if data is None else data)
train_set, val_set = get_train_val(val_rate=0.2)
# Notebook-style display of the split sizes as (train, val).
(len(train_set), len(val_set))
# 运行输出: (8320, 2080)

构建模型

def Upsample(tensor, size):
    '''Resize a feature tensor to ``size`` with ``tf.image.resize`` (bilinear).

    Args:
        tensor: 4-D Keras tensor, assumed channels-last (batch, H, W, C).
        size: target ``[height, width]``.

    Returns:
        Keras tensor of shape (batch, size[0], size[1], C).

    Note:
        A Lambda layer wraps arbitrary Python code, so a model containing it
        cannot be reliably rebuilt from a saved architecture. For integer
        scale factors, ``UpSampling2D(interpolation='bilinear')`` is a
        serializable alternative.
    '''
    name = tensor.name.split('/')[0] + '_upsample'
    channels = K.int_shape(tensor)[-1]

    def bilinear_upsample(x):
        return tf.image.resize(images=x, size=size)

    # output_shape excludes the batch axis but must include the channel axis,
    # not just the spatial size.
    return Lambda(bilinear_upsample,
                  output_shape=(size[0], size[1], channels), name=name)(tensor)


def _aspp_branch(tensor, kernel_size, dilation_rate, conv_name, index):
    '''One parallel ASPP branch: bias-free Conv2D(256) -> BatchNorm -> ReLU.

    ``index`` numbers the BatchNorm / ReLU layers (``bn_<index>``, ``relu_<index>``).
    '''
    y = Conv2D(filters=256, kernel_size=kernel_size, dilation_rate=dilation_rate,
               padding='same', kernel_initializer='he_normal',
               name=conv_name, use_bias=False)(tensor)
    y = BatchNormalization(name=f'bn_{index}')(y)
    return Activation('relu', name=f'relu_{index}')(y)


def ASPP(tensor):
    '''Atrous Spatial Pyramid Pooling head.

    Runs five branches in parallel on ``tensor``: image-level pooling, a 1x1
    conv, and 3x3 atrous convs with dilation 6, 12 and 18. Each produces 256
    channels at the input's spatial size. The branches are concatenated and
    fused with a 1x1 conv -> BatchNorm -> ReLU.

    Args:
        tensor: 4-D Keras tensor with known static height, width and channels
            (channels-last assumed -- TODO confirm for any channels-first use).

    Returns:
        Keras tensor with the same height/width as ``tensor`` and 256 channels.
    '''
    dims = K.int_shape(tensor)

    # Image-level branch: global average -> 1x1 conv -> broadcast back.
    y_pool = GlobalAveragePooling2D(name='average_pooling')(tensor)
    y_pool = Reshape((1, 1, dims[3]), name='average_pooling_reshape')(y_pool)
    y_pool = Conv2D(filters=256, kernel_size=1, padding='same',
                    kernel_initializer='he_normal', name='pool_1x1conv2d', use_bias=False)(y_pool)
    y_pool = BatchNormalization(name='bn_1')(y_pool)
    y_pool = Activation('relu', name='relu_1')(y_pool)
    # Upsampling a 1x1 map copies the pooled vector to every position, which
    # is the intended broadcast. A dilated Conv2DTranspose would instead write
    # only the four corner pixels.
    y_pool = UpSampling2D(size=(dims[1], dims[2]), name='pool_upsample')(y_pool)

    y_1 = _aspp_branch(tensor, 1, 1, 'ASPP_conv2d_d1', 2)
    y_6 = _aspp_branch(tensor, 3, 6, 'ASPP_conv2d_d6', 3)
    y_12 = _aspp_branch(tensor, 3, 12, 'ASPP_conv2d_d12', 4)
    y_18 = _aspp_branch(tensor, 3, 18, 'ASPP_conv2d_d18', 5)

    y = concatenate([y_pool, y_1, y_6, y_12, y_18], name='ASPP_concat')

    y = Conv2D(filters=256, kernel_size=1, dilation_rate=1, padding='same',
               kernel_initializer='he_normal', name='ASPP_conv2d_final', use_bias=False)(y)
    y = BatchNormalization(name='bn_final')(y)
    y = Activation('relu', name='relu_final')(y)
    return y


def DeepLabV3Plus(img_height=512, img_width=512, nclasses=1):
    '''Build a DeepLab-V3+ segmentation model on an ImageNet ResNet50 backbone.

    Encoder: ResNet50 up to ``conv4_block6_out`` followed by ``ASPP``.
    Decoder: the ASPP output is upsampled x4 and concatenated with a 48-channel
    projection of ``conv2_block3_out``. Two 3x3 conv-BN-ReLU blocks follow,
    then a 1x1 classifier, a x4 bilinear upsample and a sigmoid.

    Args:
        img_height, img_width: input size, assumed to be a multiple of 16 so
            the stride-16 and stride-4 feature maps line up after the x4
            upsample.
        nclasses: number of output channels.

    Returns:
        Keras ``Model`` mapping (H, W, 3) images to (H, W, nclasses) sigmoid
        probabilities.
    '''
    print('*** Building DeepLabv3Plus Network ***')

    base_model = ResNet50(input_shape=(img_height, img_width, 3),
                          weights='imagenet', include_top=False)

    # Encoder: stride-16 backbone features -> ASPP.
    image_features = base_model.get_layer('conv4_block6_out').output
    x_a = ASPP(image_features)
    # Stride 16 -> stride 4. Bilinear upsampling keeps each position aligned
    # with the same image region as the low-level features. A dilated
    # Conv2DTranspose would scatter the map into the corners and leave the
    # middle equal to the bias.
    x_a = UpSampling2D(size=(4, 4), interpolation='bilinear', name='aspp_upsample')(x_a)

    # Low-level (stride-4) features, projected to 48 channels.
    x_b = base_model.get_layer('conv2_block3_out').output
    x_b = Conv2D(filters=48, kernel_size=1, padding='same',
                 kernel_initializer='he_normal', name='low_level_projection', use_bias=False)(x_b)
    x_b = BatchNormalization(name='bn_low_level_projection')(x_b)
    x_b = Activation('relu', name='low_level_activation')(x_b)

    x = concatenate([x_a, x_b], name='decoder_concat')

    # conv -> BN -> ReLU: the conv has no activation of its own, so BN
    # normalises the raw response and the ReLU is applied once.
    x = Conv2D(filters=256, kernel_size=3, padding='same',
               kernel_initializer='he_normal', name='decoder_conv2d_1', use_bias=False)(x)
    x = BatchNormalization(name='bn_decoder_1')(x)
    x = Activation('relu', name='activation_decoder_1')(x)

    x = Conv2D(filters=256, kernel_size=3, padding='same',
               kernel_initializer='he_normal', name='decoder_conv2d_2', use_bias=False)(x)
    x = BatchNormalization(name='bn_decoder_2')(x)
    x = Activation('relu', name='activation_decoder_2')(x)

    # Classify at stride 4, then upsample the nclasses-channel logits to full
    # resolution. Bilinear interpolation is linear with weights summing to 1,
    # so this equals upsampling first and then applying the 1x1 conv. It only
    # materialises nclasses channels at H x W instead of 256.
    x = Conv2D(nclasses, (1, 1), name='output_layer')(x)
    x = UpSampling2D(size=(4, 4), interpolation='bilinear', name='output_upsample')(x)
    # Sigmoid gives independent per-channel probabilities (binary / multi-label
    # masks). For mutually exclusive classes use softmax instead, or drop the
    # activation and train with SparseCategoricalCrossentropy(from_logits=True).
    x = Activation('sigmoid')(x)

    model = Model(inputs=base_model.input, outputs=x, name='DeepLabV3_Plus')
    print(f'*** Output_Shape => {model.output_shape} ***')
    return model
# Build the network for the configured input size and number of output channels.
model = DeepLabV3Plus(img_height=img_h, img_width=img_w, nclasses=n_label)
# 运行输出:
# *** Building DeepLabv3Plus Network ***
# *** Output_Shape => (None, 512, 512, 1) ***
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed soft Dice coefficient, averaged over the batch axis.

    Sums run over axes 1 and 2 (height and width for channels-last masks --
    assumed), giving one score per sample and channel:
    ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    These are then averaged over axis 0.
    """
    spatial = [1, 2]
    overlap = K.sum(y_true * y_pred, axis=spatial)
    total = K.sum(y_true, axis=spatial) + K.sum(y_pred, axis=spatial)
    per_sample = (2. * overlap + smooth) / (total + smooth)
    return K.mean(per_sample, axis=0)

def dice_coef_loss(y_true, y_pred):
    """Dice loss ``1 - dice_coef(y_true, y_pred)``; 0 for a perfect overlap."""
    return 1 - dice_coef(y_true, y_pred, smooth=1)

def bce_logdice_loss(y_true, y_pred):
    """Binary cross-entropy plus the negative log of the soft Dice coefficient.

    ``-log(dice)`` is 0 for a perfect overlap and rises steeply as the overlap
    shrinks, penalising poor masks more sharply than ``1 - dice``.
    """
    return binary_crossentropy(y_true, y_pred) - K.log(dice_coef(y_true, y_pred))

# The model has a single output, so it takes a single loss; a list of losses
# must have exactly one entry per model output.
model.compile(optimizer=Adam(learning_rate=1e-4), loss=bce_logdice_loss, metrics=['accuracy'])

模型训练

EPOCHS = 5
BS = 4

model_path = '/content/drive/MyDrive/models/deeplab_v3.h5'

## Callbacks
# Append per-epoch metrics to a CSV training log.
csvlogger = CSVLogger('/content/drive/MyDrive/training_deeplab_v3.csv', separator=',', append=True)
# Learning-rate decay: multiply the LR by 0.1 after 2 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1)
# Checkpoint the full model whenever val_accuracy improves, so training can be
# resumed. save_freq='epoch' replaces the deprecated `period=1` argument.
checkpoint = ModelCheckpoint(model_path, monitor='val_accuracy', verbose=1, save_best_only=True,
                             save_weights_only=False, mode='max', save_freq='epoch')
# Early stopping after 4 epochs without val_accuracy improvement.
early_stopping = EarlyStopping(monitor='val_accuracy', patience=4, verbose=1, mode='max')


train_numb = len(train_set)
valid_numb = len(val_set)
print("the number of train data is", train_numb)
print("the number of val data is", valid_numb)
H = model.fit(
    generateData(BS, train_set),
    steps_per_epoch=train_numb // BS,
    epochs=EPOCHS, verbose=1,
    validation_data=generateValidData(BS, val_set),
    validation_steps=valid_numb // BS,
    callbacks=[checkpoint, early_stopping, reduce_lr, csvlogger],
    max_queue_size=BS,
)
# 运行输出（节选）:
# WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.
# the number of train data is 8320
# the number of val data is 2080
# Epoch 1/5
#   75/2080 [>.............................] - ETA: 48:22 - loss: 0.6673 - accuracy: 0.7521

思考总结

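  1. 模型保存：通常小模型会同时保存模型结构和权重，大模型只保存权重。如果上采样使用了 Lambda 层（如 Upsample 函数），由于 Lambda 封装的是任意 Python 代码，模型结构无法可靠地保存和恢复。此时可以只保存权重，先用代码重新定义模型结构，再加载权重。改用 Keras 内置的 UpSampling2D(interpolation='bilinear') 层则可以避免这一问题；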
  2. 训练时间：DeepLab-V3+ 的训练时间约为 U-Net 的 10 倍。可能的原因包括：ResNet50 主干网络本身计算量较大；输入分辨率为 512×512；解码器在 1/4 分辨率（128×128）上使用 256 通道的 3×3 卷积。上采样和空洞卷积是否是主要瓶颈，还需要借助性能分析（profiling）加以验证，欢迎知道的朋友指点；
上一篇:从零开始学keras之过拟合与欠拟合


下一篇:VGGnet论文解读及代码实现