UNet Building Segmentation and Contour Recognition

The UNet Semantic Segmentation Model

The UNet Model

The UNet semantic segmentation model has been popular in several Kaggle image-recognition competitions, such as data-science-bowl-2018 and airbus-ship-detection. It also performs very well on medical images. It is simple, efficient, easy to understand and build, and does not require a particularly large training dataset.

The network architecture from the UNet paper is shown in the figure below. The structure is fairly simple: the left half acts as an encoder and the right half as a decoder. The encoder extracts features, mainly by applying 3×3 convolutions followed by max pooling. The decoder mainly performs up-convolutions and 3×3 convolutions. Two details deserve attention: 1. UNet copies and crops encoder feature maps over to the decoder ("copy and crop"); 2. the final output layer is a 1×1 convolution.

UNet upsamples four times and uses a skip connection at each stage, instead of supervising and back-propagating the loss only on the highest-level semantic features. This ensures that the recovered feature maps fuse more low-level features and that features at different scales are combined, enabling multi-scale prediction and deep supervision. The four upsampling steps also let the segmentation map recover fine details such as edges. The skip connections carry a great deal of information from the input image and help restore what is lost to downsampling; in this respect they behave much like residual connections.

(Figure: UNet network architecture from the original paper)

To summarize the UNet structure: the left side is an encoder that extracts features, and the right side is a decoder that produces the output via upsampling. Compared with FCN, UNet fuses feature maps by concatenation rather than addition. The benefit: deep layers have larger receptive fields and capture the image's semantic content, while shallow feature maps capture texture. Feature maps at every depth are therefore useful, and concatenating them lets the network learn from both, as the short sketch below illustrates.
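
To make the fusion difference concrete, here is a minimal, hedged sketch (assuming TensorFlow 2.x; the random tensors stand in for real feature maps) contrasting FCN-style addition with UNet-style concatenation:

import tensorflow as tf
from tensorflow.keras.layers import Add, Concatenate

# two feature maps with matching spatial size, e.g. 64x64 with 128 channels
shallow = tf.random.normal((1, 64, 64, 128))
deep_up = tf.random.normal((1, 64, 64, 128))

# FCN-style fusion: element-wise addition, channel count stays 128
fused_add = Add()([shallow, deep_up])                 # (1, 64, 64, 128)

# UNet-style fusion: concatenation stacks channels to 256; the
# following 3x3 convolutions then learn how to mix them
fused_cat = Concatenate(axis=3)([shallow, deep_up])   # (1, 64, 64, 256)
print(fused_add.shape, fused_cat.shape)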

The figure below shows my results using UNet to detect buildings in aerial imagery. A dataset of roughly a dozen images was enough, and the trained results are acceptable.

(Figure: building segmentation results on aerial imagery)

Code Implementation

This code is adapted from another author's work: a Keras implementation of the UNet model (see reference 2).

Model

The backbone is VGG16: each block applies two or three 3×3 convolutions, records the feature map produced, and then max-pools. This repeats across five blocks.

from tensorflow.keras import layers
from tensorflow.keras.initializers import RandomNormal as random_normal

def VGG16(img_input):
    # Block 1
    # 512,512,3 -> 512,512,64
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block1_conv1')(img_input)
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block1_conv2')(x)
    feat1 = x
    # 512,512,64 -> 256,256,64
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    # 256,256,64 -> 256,256,128
    x = layers.Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block2_conv1')(x)
    x = layers.Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block2_conv2')(x)
    feat2 = x
    # 256,256,128 -> 128,128,128
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)


    # Block 3
    # 128,128,128 -> 128,128,256
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block3_conv1')(x)
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block3_conv2')(x)
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block3_conv3')(x)
    feat3 = x
    # 128,128,256 -> 64,64,256
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    # 64,64,256 -> 64,64,512
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block4_conv1')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block4_conv2')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block4_conv3')(x)
    feat4 = x
    # 64,64,512 -> 32,32,512
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5
    # 32,32,512 -> 32,32,512
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block5_conv1')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block5_conv2')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block5_conv3')(x)
    feat5 = x
    return feat1, feat2, feat3, feat4, feat5

That covers the encoder's feature extraction. The code below implements the decoder's upsampling path:

from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import RandomNormal as random_normal

# note: the shape comments below assume a 512x512x3 input
def Unet(input_shape=(256,256,3), num_classes=21):
    inputs = Input(input_shape)
    feat1, feat2, feat3, feat4, feat5 = VGG16(inputs) 
    channels = [64, 128, 256, 512]
    # 32, 32, 512 -> 64, 64, 512
    P5_up = UpSampling2D(size=(2, 2))(feat5)
    # 64, 64, 512 + 64, 64, 512 -> 64, 64, 1024
    P4 = Concatenate(axis=3)([feat4, P5_up])
    # 64, 64, 1024 -> 64, 64, 512
    P4 = Conv2D(channels[3], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P4)
    P4 = Conv2D(channels[3], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P4)
    # 64, 64, 512 -> 128, 128, 512
    P4_up = UpSampling2D(size=(2, 2))(P4)
    # 128, 128, 256 + 128, 128, 512 -> 128, 128, 768
    P3 = Concatenate(axis=3)([feat3, P4_up])
    # 128, 128, 768 -> 128, 128, 256
    P3 = Conv2D(channels[2], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P3)
    P3 = Conv2D(channels[2], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P3)

    # 128, 128, 256 -> 256, 256, 256
    P3_up = UpSampling2D(size=(2, 2))(P3)
    # 256, 256, 256 + 256, 256, 128 -> 256, 256, 384
    P2 = Concatenate(axis=3)([feat2, P3_up])
    # 256, 256, 384 -> 256, 256, 128
    P2 = Conv2D(channels[1], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P2)
    P2 = Conv2D(channels[1], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P2)

    # 256, 256, 128 -> 512, 512, 128
    P2_up = UpSampling2D(size=(2, 2))(P2)
    # 512, 512, 128 + 512, 512, 64 -> 512, 512, 192
    P1 = Concatenate(axis=3)([feat1, P2_up])
    # 512, 512, 192 -> 512, 512, 64
    P1 = Conv2D(channels[0], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P1)
    P1 = Conv2D(channels[0], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P1)

    # 512, 512, 64 -> 512, 512, num_classes
    P1 = Conv2D(num_classes, 1, activation="softmax")(P1)

    model = Model(inputs=inputs, outputs=P1)
    return model

With this, the structure of the UNet model is complete.
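
As a quick usage sketch (hedged: assumes TensorFlow 2.x and the dice_loss_with_CE factory defined in the Train section below; two classes, background + building, are assumed), the model can be built and compiled like this:

from tensorflow.keras.optimizers import Adam

# build the network for 512x512 RGB inputs and 2 output classes
model = Unet(input_shape=(512, 512, 3), num_classes=2)
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss=dice_loss_with_CE(),
              metrics=['accuracy'])
model.summary()  # check the four down/up stages and the skip connections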

Train

  • Data loading

    There is not much to say about data loading. The main steps are: 1. data augmentation, applying random transforms to each image: resizing, flipping, color distortion, and so on; 2. encoding the label masks as one-hot vectors. (A usage sketch follows the class below.)

    
    import os
    import cv2
    import numpy as np
    from random import shuffle
    from PIL import Image

    def rand(a=0, b=1):
        # uniform random float in [a, b)
        return np.random.rand() * (b - a) + a

    class Generator(object):
        def __init__(self,batch_size,train_lines,image_size,num_classes,dataset_path):
            self.batch_size     = batch_size
            self.train_lines    = train_lines
            self.train_batches  = len(train_lines)
            self.image_size     = image_size
            self.num_classes    = num_classes
            self.dataset_path   = dataset_path
    
        def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5):
            label = Image.fromarray(np.array(label))
    
            h, w = input_shape
            # resize image
            rand_jit1 = rand(1-jitter,1+jitter)
            rand_jit2 = rand(1-jitter,1+jitter)
            new_ar = w/h * rand_jit1/rand_jit2
    
            scale = rand(0.25, 2)
            if new_ar < 1:
                nh = int(scale*h)
                nw = int(nh*new_ar)
            else:
                nw = int(scale*w)
                nh = int(nw/new_ar)
            image = image.resize((nw,nh), Image.BICUBIC)
            label = label.resize((nw,nh), Image.NEAREST)
            label = label.convert("L")
            
            # flip image or not
            flip = rand()<.5
            if flip: 
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)
            
            # place image
            dx = int(rand(0, w-nw))
            dy = int(rand(0, h-nh))
            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_label = Image.new('L', (w,h), (0))
            new_image.paste(image, (dx, dy))
            new_label.paste(label, (dx, dy))
            image = new_image
            label = new_label
    
            # distort image
            hue = rand(-hue, hue)
            sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat)
            val = rand(1, val) if rand()<.5 else 1/rand(1, val)
            x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
            x[..., 0] += hue*360
            # OpenCV float32 HSV keeps H in [0, 360]; wrap out-of-range hues
            x[..., 0][x[..., 0]>360] -= 360
            x[..., 0][x[..., 0]<0] += 360
            x[..., 1] *= sat
            x[..., 2] *= val
            x[x[:,:, 0]>360, 0] = 360
            x[:, :, 1:][x[:, :, 1:]>1] = 1
            x[x<0] = 0
            image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
            return image_data,label
            
        def generate(self, random_data = True):
            i = 0
            length = len(self.train_lines)
            inputs = []
            targets = []
            while True:
                if i == 0:
                    shuffle(self.train_lines)
                annotation_line = self.train_lines[i]
                name = annotation_line.split()[0]
    
            # read the image and its label mask from disk
                jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "JPEGImages"), name + ".jpg"))
                png = Image.open(os.path.join(os.path.join(self.dataset_path, "labels"), name + ".png"))
    
                if random_data:
                    jpg, png = self.get_random_data(jpg,png,(int(self.image_size[1]),int(self.image_size[0])))
            else:
                # letterbox_image: external helper that resizes image and
                # label with gray padding and no augmentation
                jpg, png = letterbox_image(jpg, png, (int(self.image_size[1]),int(self.image_size[0])))
                
                inputs.append(np.array(jpg)/255)
                
                png = np.array(png)
                png[png >= self.num_classes] = self.num_classes
                seg_labels = np.eye(self.num_classes+1)[png.reshape([-1])]
                seg_labels = seg_labels.reshape((int(self.image_size[1]),int(self.image_size[0]),self.num_classes+1))
                
                targets.append(seg_labels)
                i = (i + 1) % length
                if len(targets) == self.batch_size:
                    tmp_inp = np.array(inputs)
                    tmp_targets = np.array(targets)
                    inputs = []
                    targets = []
                    yield tmp_inp, tmp_targets
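
    A hedged usage sketch (train.txt, the batch size, and the dataset path are placeholder assumptions; train_lines holds one sample name per line, pairing JPEGImages/<name>.jpg with labels/<name>.png):

    with open("train.txt") as f:
        train_lines = f.readlines()

    gen = Generator(batch_size=2,
                    train_lines=train_lines,
                    image_size=(512, 512, 3),
                    num_classes=2,
                    dataset_path="dataset").generate(random_data=True)

    # model is the compiled network from the Model section
    model.fit(gen,
              steps_per_epoch=max(1, len(train_lines) // 2),
              epochs=50)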
    
  • Loss functions: dice_loss_with_CE / CE

    import tensorflow as tf
    from tensorflow.keras import backend as K

    def dice_loss_with_CE(beta=1, smooth = 1e-5):
        def _dice_loss_with_CE(y_true, y_pred):
            y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())

            # cross-entropy term; y_true carries num_classes+1 channels
            # (the last marks out-of-range pixels), so drop the last one
            CE_loss = - y_true[...,:-1] * K.log(y_pred)
            CE_loss = K.mean(K.sum(CE_loss, axis = -1))

            # per-class true positives / false positives / false negatives
            tp = K.sum(y_true[...,:-1] * y_pred, axis=[0,1,2])
            fp = K.sum(y_pred         , axis=[0,1,2]) - tp
            fn = K.sum(y_true[...,:-1], axis=[0,1,2]) - tp

            score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
            score = tf.reduce_mean(score)
            dice_loss = 1 - score
            return CE_loss + dice_loss
        return _dice_loss_with_CE

    def CE():
        def _CE(y_true, y_pred):
            y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())

            CE_loss = - y_true[...,:-1] * K.log(y_pred)
            CE_loss = K.mean(K.sum(CE_loss, axis = -1))
            return CE_loss
        return _CE
    

    For details, see: Summary of semantic segmentation loss functions.
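
    As a toy sanity check of the dice score above (beta=1, smooth=1e-5, and omitting the extra out-of-range channel for simplicity): a perfect one-hot prediction should give a score of about 1, i.e. a dice loss of about 0. A plain numpy sketch, separate from the training code:

    import numpy as np

    y_true = np.array([[1, 0], [0, 1]], dtype=np.float32)  # 2 pixels, 2 classes
    y_pred = y_true.copy()                                  # perfect prediction

    tp = (y_true * y_pred).sum(axis=0)
    fp = y_pred.sum(axis=0) - tp
    fn = y_true.sum(axis=0) - tp
    smooth, beta = 1e-5, 1
    score = ((1 + beta**2) * tp + smooth) / ((1 + beta**2) * tp + beta**2 * fn + fp + smooth)
    print(1 - score.mean())  # ~0.0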

  • Recording training loss

    import os
    import datetime

    import matplotlib.pyplot as plt
    import scipy.signal
    from tensorflow import keras

    class LossHistory(keras.callbacks.Callback):
        def __init__(self, log_dir):
            curr_time = datetime.datetime.now()
            time_str = datetime.datetime.strftime(curr_time,'%Y_%m_%d_%H_%M_%S')
            self.log_dir    = log_dir
            self.time_str   = time_str
            self.save_path  = os.path.join(self.log_dir, "loss_" + str(self.time_str))
            self.losses     = []
            self.val_loss   = []

            os.makedirs(self.save_path, exist_ok=True)
    
        def on_epoch_end(self, epoch, logs=None):
            # Keras passes the epoch index here; guard against a missing logs dict
            logs = logs or {}
            self.losses.append(logs.get('loss'))
            self.val_loss.append(logs.get('val_loss'))
            with open(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".txt"), 'a') as f:
                f.write(str(logs.get('loss')))
                f.write("\n")
            with open(os.path.join(self.save_path, "epoch_val_loss_" + str(self.time_str) + ".txt"), 'a') as f:
                f.write(str(logs.get('val_loss')))
                f.write("\n")
            # self.loss_plot()
    
        def loss_plot(self):
            iters = range(len(self.losses))
    
            plt.figure()
            plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
            try:
                if len(self.losses) < 25:
                    num = 5
                else:
                    num = 15
                
                plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
            except:
                pass
    
            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Loss Curve')
            plt.legend(loc="upper right")
    
            plt.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png"))
    
            plt.cla()
            plt.close("all")
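
    A hedged sketch of wiring the callback into training (ModelCheckpoint saves weights only, matching the load_weights call in the Predict section; gen is the training generator from above, and val_gen is an assumed validation generator built the same way):

    from tensorflow.keras.callbacks import ModelCheckpoint

    loss_history = LossHistory("logs")
    checkpoint = ModelCheckpoint("logs/ep{epoch:03d}-val_loss{val_loss:.3f}.h5",
                                 monitor='val_loss',
                                 save_weights_only=True)

    model.fit(gen,
              steps_per_epoch=max(1, len(train_lines) // 2),
              validation_data=val_gen,
              validation_steps=10,
              epochs=50,
              callbacks=[loss_history, checkpoint])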
    
    

Predict

Only the weights were saved during training, so we must first build the network and then load the weights. The model code is as follows:

import copy
import colorsys
import time

import numpy as np
from PIL import Image

# the model-building function from the Model section, imported under a
# lowercase alias to avoid clashing with this class (module path assumed):
# from nets.unet import Unet as unet

class Unet(object):

    #---------------------------------------------------#
    #   Initialize the UNet predictor
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        _defaults = {
            "model_path"        : kwargs['model'],
            "model_image_size"  : kwargs['model_image_size'],
            "num_classes"       : kwargs['num_classes']
        }
        self.__dict__.update(_defaults)
        self.generate()

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        #-------------------------------#
        #   Build the network and load the trained weights
        #-------------------------------#
        self.model = unet(self.model_image_size, self.num_classes)

        self.model.load_weights(self.model_path)
        print('{} model loaded.'.format(self.model_path))
        
        if self.num_classes <= 21:
            self.colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), 
                    (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), 
                    (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), (128, 64, 12)]
        else:
            # assign each class a distinct color via evenly spaced hues
            hsv_tuples = [(x / self.num_classes, 1., 1.)
                        for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(
                map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                    self.colors))

    def letterbox_image(self ,image, size):
        image = image.convert("RGB")
        iw, ih = image.size
        w, h = size
        scale = min(w/iw, h/ih)
        nw = int(iw*scale)
        nh = int(ih*scale)

        image = image.resize((nw,nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128,128,128))
        new_image.paste(image, ((w-nw)//2, (h-nh)//2))
        return new_image,nw,nh

    #---------------------------------------------------#
    #   Run segmentation on a single image
    #---------------------------------------------------#
    def detect_image(self, image):
        #---------------------------------------------------------#
        #   Convert the image to RGB so grayscale inputs
        #   do not fail at prediction time.
        #---------------------------------------------------------#
        image = image.convert('RGB')

        #---------------------------------------------------#
        #   Keep a copy of the input image for blending later.
        #---------------------------------------------------#
        old_img = copy.deepcopy(image)
        original_h = np.array(image).shape[0]
        original_w = np.array(image).shape[1]

        #---------------------------------------------------#
        #   Resize without distortion (letterbox with gray
        #   padding) and normalize to [0, 1].
        #---------------------------------------------------#
        img, nw, nh = self.letterbox_image(image,(self.model_image_size[1],self.model_image_size[0]))
        img = np.asarray([np.array(img)/255])
        pr = self.model.predict(img)[0]
        #---------------------------------------------------#
        #   Take the argmax class of every pixel.
        #---------------------------------------------------#
        pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
        #--------------------------------------#
        #   Crop away the gray padding bars.
        #--------------------------------------#
        pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]

        #------------------------------------------------#
        #   Create a new image and color each pixel
        #   according to its predicted class.
        #------------------------------------------------#
        seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
        for c in range(self.num_classes):
            seg_img[:,:,0] += ((pr[:,: ] == c )*( self.colors[c][0] )).astype('uint8')
            seg_img[:,:,1] += ((pr[:,: ] == c )*( self.colors[c][1] )).astype('uint8')
            seg_img[:,:,2] += ((pr[:,: ] == c )*( self.colors[c][2] )).astype('uint8')

        image = Image.fromarray(np.uint8(seg_img)).resize((original_w,original_h), Image.NEAREST)
        blend_image = Image.blend(old_img,image,0.7)

        return image, blend_image

    def get_FPS(self, image, test_interval):
        original_h = np.array(image).shape[0]
        original_w = np.array(image).shape[1]

        img, nw, nh = self.letterbox_image(image,(self.model_image_size[1],self.model_image_size[0]))
        img = np.asarray([np.array(img)/255])

        pr = self.model.predict(img)[0]
        pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
        pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
        
        image = Image.fromarray(np.uint8(pr)).resize((original_w,original_h), Image.NEAREST)

        t1 = time.time()
        for _ in range(test_interval):
            pr = self.model.predict(img)[0]
            pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
            pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
            image = Image.fromarray(np.uint8(pr)).resize((original_w,original_h), Image.NEAREST)
            
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time
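
A hedged inference sketch (the weight file and test image paths are placeholders for your own files; num_classes must match training):

from PIL import Image

predictor = Unet(model="logs/unet_weights.h5",
                 model_image_size=(512, 512, 3),
                 num_classes=2)

img = Image.open("test.jpg")
mask, blend = predictor.detect_image(img)
mask.save("mask.png")    # colored class mask
blend.save("blend.png")  # mask blended over the input image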
        

References

  1. The UNet paper: U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015)

  2. Keras implementation of the UNet model
