UNet for Semantic Segmentation
The UNet Model
The UNet segmentation model is popular in Kaggle image-recognition competitions such as data-science-bowl-2018 and airbus-ship-detection, and it also performs very well on medical images. It is simple, efficient, easy to understand and to build, and it does not need a particularly large training dataset.
The network structure from the UNet paper is shown in the figure below. It is fairly simple: the left half is an encoder and the right half a decoder. The encoder extracts features, mainly by applying 3×3 convolutions followed by max pooling. The decoder mainly applies up-convolutions and 3×3 convolutions. Two details deserve attention: 1. UNet copies and crops encoder feature maps over to the decoder; 2. the final output layer is a 1×1 convolution. UNet upsamples four times and uses a skip connection at each stage, rather than supervising and backpropagating the loss only on the highest-level semantic features. This ensures the recovered feature maps fuse in more low-level features, and that features at different scales are combined, enabling multi-scale prediction and deep supervision. The four upsampling steps also let the segmentation map recover fine details such as edges. The skip connections carry a great deal of information from the input image, helping to restore what is lost during downsampling; in that sense they behave much like residual connections.
To summarize the UNet architecture: the left side is an encoder whose job is to extract features; the right side is a decoder that produces the output via upsampling. Compared with FCN, UNet fuses feature maps by concatenation. The benefit is that deep layers have a larger receptive field and focus on the image's essential semantics, while shallow feature maps capture texture. Both kinds of features are useful, and concatenating them lets the network learn from both.
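As a minimal sketch of that concatenation-style skip connection, here is a toy one-stage encoder/decoder (the channel counts and input size are illustrative only, not the full network built later):

```python
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from keras.models import Model

inp   = Input((128, 128, 3))
skip  = Conv2D(64, 3, activation='relu', padding='same')(inp)    # shallow, texture-level features
deep  = MaxPooling2D(2)(skip)                                    # downsample
deep  = Conv2D(128, 3, activation='relu', padding='same')(deep)  # deeper, more semantic features
up    = UpSampling2D(2)(deep)                                    # restore spatial resolution
fused = Concatenate(axis=-1)([skip, up])                         # skip connection: concatenate, not add
out   = Conv2D(64, 3, activation='relu', padding='same')(fused)
toy   = Model(inp, out)
```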
The image below shows buildings I extracted from remote-sensing imagery with UNet; a dataset of only a dozen or so images was enough, and the trained results are acceptable.
Code Implementation
The code here is adapted from another author's implementation: a Keras implementation of the UNet model.
Model
The backbone is VGG16: each block applies two or three 3×3 convolutions, records the resulting feature map, and then max-pools. This repeats for five blocks.
```python
from keras import layers
from keras.initializers import random_normal


def VGG16(img_input):
    # Block 1
    # 512,512,3 -> 512,512,64
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block1_conv1')(img_input)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block1_conv2')(x)
    feat1 = x
    # 512,512,64 -> 256,256,64
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    # 256,256,64 -> 256,256,128
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block2_conv1')(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block2_conv2')(x)
    feat2 = x
    # 256,256,128 -> 128,128,128
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # Block 3
    # 128,128,128 -> 128,128,256
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block3_conv1')(x)
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block3_conv2')(x)
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block3_conv3')(x)
    feat3 = x
    # 128,128,256 -> 64,64,256
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    # 64,64,256 -> 64,64,512
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block4_conv1')(x)
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block4_conv2')(x)
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block4_conv3')(x)
    feat4 = x
    # 64,64,512 -> 32,32,512
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5
    # 32,32,512 -> 32,32,512
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block5_conv1')(x)
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block5_conv2')(x)
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same',
                      kernel_initializer=random_normal(stddev=0.02), name='block5_conv3')(x)
    feat5 = x

    return feat1, feat2, feat3, feat4, feat5
```
The code above is the encoder's feature extraction; the code below is the upsampling decoder:
```python
from keras.layers import Input, Conv2D, UpSampling2D, Concatenate
from keras.models import Model
from keras.initializers import random_normal


def Unet(input_shape=(256, 256, 3), num_classes=21):
    # NOTE: the shape comments below assume a 512x512x3 input.
    inputs = Input(input_shape)
    feat1, feat2, feat3, feat4, feat5 = VGG16(inputs)
    channels = [64, 128, 256, 512]

    # 32, 32, 512 -> 64, 64, 512
    P5_up = UpSampling2D(size=(2, 2))(feat5)
    # 64, 64, 512 + 64, 64, 512 -> 64, 64, 1024
    P4 = Concatenate(axis=3)([feat4, P5_up])
    # 64, 64, 1024 -> 64, 64, 512
    P4 = Conv2D(channels[3], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P4)
    P4 = Conv2D(channels[3], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P4)

    # 64, 64, 512 -> 128, 128, 512
    P4_up = UpSampling2D(size=(2, 2))(P4)
    # 128, 128, 256 + 128, 128, 512 -> 128, 128, 768
    P3 = Concatenate(axis=3)([feat3, P4_up])
    # 128, 128, 768 -> 128, 128, 256
    P3 = Conv2D(channels[2], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P3)
    P3 = Conv2D(channels[2], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P3)

    # 128, 128, 256 -> 256, 256, 256
    P3_up = UpSampling2D(size=(2, 2))(P3)
    # 256, 256, 256 + 256, 256, 128 -> 256, 256, 384
    P2 = Concatenate(axis=3)([feat2, P3_up])
    # 256, 256, 384 -> 256, 256, 128
    P2 = Conv2D(channels[1], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P2)
    P2 = Conv2D(channels[1], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P2)

    # 256, 256, 128 -> 512, 512, 128
    P2_up = UpSampling2D(size=(2, 2))(P2)
    # 512, 512, 128 + 512, 512, 64 -> 512, 512, 192
    P1 = Concatenate(axis=3)([feat1, P2_up])
    # 512, 512, 192 -> 512, 512, 64
    P1 = Conv2D(channels[0], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P1)
    P1 = Conv2D(channels[0], 3, activation='relu', padding='same',
                kernel_initializer=random_normal(stddev=0.02))(P1)

    # 512, 512, 64 -> 512, 512, num_classes
    P1 = Conv2D(num_classes, 1, activation="softmax")(P1)

    model = Model(inputs=inputs, outputs=P1)
    return model
```
With this, the structure of the UNet model is complete.
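A quick sanity check that the model builds (the input shape and class count here are example values, e.g. building vs. background):

```python
model = Unet(input_shape=(512, 512, 3), num_classes=2)
model.summary()  # the final layer outputs (512, 512, 2): a per-pixel softmax over classes
```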
Train
- Data loading
There is not much to say about the data loading itself; the main operations are: 1. data augmentation, randomly transforming each image: first resize it, then flip it, then distort its colors, and so on; 2. encoding the labels in one-hot form.
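The one-hot step uses the `np.eye` indexing trick that appears in the generator below; here is a minimal standalone example (the 2×2 mask and class count are made up for illustration; the extra `+1` channel absorbs out-of-range pixel values):

```python
import numpy as np

num_classes = 3
png = np.array([[0, 1],
                [2, 1]])                       # each pixel holds a class index
# index rows of an identity matrix to one-hot encode every pixel
seg_labels = np.eye(num_classes + 1)[png.reshape(-1)]
seg_labels = seg_labels.reshape(2, 2, num_classes + 1)
print(seg_labels[0, 1])                        # pixel of class 1 -> [0. 1. 0. 0.]
```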
```python
import os
from random import shuffle

import cv2
import numpy as np
from PIL import Image


def rand(a=0, b=1):
    # uniform random number in [a, b); helper used throughout the augmentation
    return np.random.rand() * (b - a) + a


class Generator(object):
    def __init__(self, batch_size, train_lines, image_size, num_classes, dataset_path):
        self.batch_size = batch_size
        self.train_lines = train_lines
        self.train_batches = len(train_lines)
        self.image_size = image_size
        self.num_classes = num_classes
        self.dataset_path = dataset_path

    def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5):
        label = Image.fromarray(np.array(label))
        h, w = input_shape

        # resize image with random aspect-ratio jitter and scale
        rand_jit1 = rand(1 - jitter, 1 + jitter)
        rand_jit2 = rand(1 - jitter, 1 + jitter)
        new_ar = w / h * rand_jit1 / rand_jit2
        scale = rand(0.25, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)
        label = label.resize((nw, nh), Image.NEAREST)
        label = label.convert("L")

        # flip image or not
        flip = rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            label = label.transpose(Image.FLIP_LEFT_RIGHT)

        # place the resized image at a random offset on a gray canvas
        dx = int(rand(0, w - nw))
        dy = int(rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_label = Image.new('L', (w, h), (0))
        new_image.paste(image, (dx, dy))
        new_label.paste(label, (dx, dy))
        image = new_image
        label = new_label

        # distort image: random hue/saturation/value jitter in HSV space
        hue = rand(-hue, hue)
        sat = rand(1, sat) if rand() < .5 else 1 / rand(1, sat)
        val = rand(1, val) if rand() < .5 else 1 / rand(1, val)
        x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 0] += hue * 360
        x[..., 0][x[..., 0] > 360] -= 360   # wrap hue into [0, 360], OpenCV's float HSV range
        x[..., 0][x[..., 0] < 0] += 360
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255
        return image_data, label

    def generate(self, random_data=True):
        i = 0
        length = len(self.train_lines)
        inputs = []
        targets = []
        while True:
            if i == 0:
                shuffle(self.train_lines)
            annotation_line = self.train_lines[i]
            name = annotation_line.split()[0]

            # read the image and its label mask from disk
            jpg = Image.open(os.path.join(self.dataset_path, "JPEGImages", name + ".jpg"))
            png = Image.open(os.path.join(self.dataset_path, "labels", name + ".png"))

            if random_data:
                jpg, png = self.get_random_data(jpg, png,
                                                (int(self.image_size[1]), int(self.image_size[0])))
            else:
                # letterbox_image (defined elsewhere) resizes without distortion
                jpg, png = letterbox_image(jpg, png,
                                           (int(self.image_size[1]), int(self.image_size[0])))

            inputs.append(np.array(jpg) / 255)
            png = np.array(png)
            png[png >= self.num_classes] = self.num_classes
            # one-hot encode the mask; the extra channel absorbs out-of-range pixels
            seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])]
            seg_labels = seg_labels.reshape((int(self.image_size[1]),
                                             int(self.image_size[0]),
                                             self.num_classes + 1))
            targets.append(seg_labels)

            i = (i + 1) % length
            if len(targets) == self.batch_size:
                tmp_inp = np.array(inputs)
                tmp_targets = np.array(targets)
                inputs = []
                targets = []
                yield tmp_inp, tmp_targets
```
- CE / dice_loss_with_CE
```python
import tensorflow as tf
from keras import backend as K


def dice_loss_with_CE(beta=1, smooth=1e-5):
    def _dice_loss_with_CE(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())

        # cross-entropy term; y_true's last channel is the extra "out of range"
        # class added by the generator, so it is sliced off with [..., :-1]
        CE_loss = - y_true[..., :-1] * K.log(y_pred)
        CE_loss = K.mean(K.sum(CE_loss, axis=-1))

        # soft F-beta (dice) term, computed per class over the whole batch
        tp = K.sum(y_true[..., :-1] * y_pred, axis=[0, 1, 2])
        fp = K.sum(y_pred, axis=[0, 1, 2]) - tp
        fn = K.sum(y_true[..., :-1], axis=[0, 1, 2]) - tp

        score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
        score = tf.reduce_mean(score)
        dice_loss = 1 - score
        # dice_loss = tf.Print(dice_loss, [dice_loss, CE_loss])
        return CE_loss + dice_loss
    return _dice_loss_with_CE


def CE():
    def _CE(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        CE_loss = - y_true[..., :-1] * K.log(y_pred)
        CE_loss = K.mean(K.sum(CE_loss, axis=-1))
        # dice_loss = tf.Print(CE_loss, [CE_loss])
        return CE_loss
    return _CE
```
For more detail, see: a summary of semantic segmentation loss functions.
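Wiring a loss into training might look like the following sketch (the optimizer and learning rate are example choices, not from the original post; note that the generator's targets carry `num_classes + 1` channels, which the loss slices off with `[..., :-1]`):

```python
from keras.optimizers import Adam

# `model` is the network from the Model section; dice_loss_with_CE is defined above
model.compile(loss=dice_loss_with_CE(),   # or CE() for plain cross-entropy
              optimizer=Adam(lr=1e-4))
```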
- Recording the training loss
```python
import datetime
import os

import keras
import matplotlib.pyplot as plt
import scipy.signal


class LossHistory(keras.callbacks.Callback):
    def __init__(self, log_dir):
        curr_time = datetime.datetime.now()
        time_str = datetime.datetime.strftime(curr_time, '%Y_%m_%d_%H_%M_%S')
        self.log_dir = log_dir
        self.time_str = time_str
        self.save_path = os.path.join(self.log_dir, "loss_" + str(self.time_str))
        self.losses = []
        self.val_loss = []
        os.makedirs(self.save_path)

    def on_epoch_end(self, epoch, logs={}):
        # append the epoch losses to in-memory lists and to text log files
        self.losses.append(logs.get('loss'))
        self.val_loss.append(logs.get('val_loss'))
        with open(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".txt"), 'a') as f:
            f.write(str(logs.get('loss')))
            f.write("\n")
        with open(os.path.join(self.save_path, "epoch_val_loss_" + str(self.time_str) + ".txt"), 'a') as f:
            f.write(str(logs.get('val_loss')))
            f.write("\n")
        # self.loss_plot()

    def loss_plot(self):
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')
        try:
            # Savitzky-Golay smoothing of the raw curves
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15
            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green',
                     linestyle='--', linewidth=2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513',
                     linestyle='--', linewidth=2, label='smooth val loss')
        except:
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('A Loss Curve')
        plt.legend(loc="upper right")
        plt.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png"))
        plt.cla()
        plt.close("all")
```
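A training sketch that ties the generator, loss, and callback together (the paths, split file, and hyperparameters are placeholders, not from the original post):

```python
import os

# Placeholder configuration; adjust to your dataset layout.
batch_size   = 2
image_size   = (512, 512)
num_classes  = 2
dataset_path = "VOCdevkit/VOC2007"   # hypothetical: JPEGImages/ and labels/ live here

# one image name per line, as Generator.generate() expects
with open(os.path.join(dataset_path, "train.txt")) as f:
    train_lines = f.readlines()

gen          = Generator(batch_size, train_lines, image_size, num_classes, dataset_path)
loss_history = LossHistory("logs")

# pass validation_data as well if you want val_loss written to the log files
model.fit_generator(gen.generate(),
                    steps_per_epoch=max(1, len(train_lines) // batch_size),
                    epochs=50,
                    callbacks=[loss_history])
```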
Predict
Only the weights were saved during training, so at prediction time we first build the network and then load the weights. The model code is as follows:
```python
import colorsys
import copy
import time

import numpy as np
from PIL import Image

# The model-building function from the Model section is expected in scope as
# lowercase `unet`; in the reference repo it lives in a separate module.


class Unet(object):
    #---------------------------------------------------#
    #   Initialize the UNet wrapper
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        _defaults = {
            "model_path": kwargs['model'],
            "model_image_size": kwargs['model_image_size'],
            "num_classes": kwargs['num_classes']
        }
        self.__dict__.update(_defaults)
        self.generate()

    #---------------------------------------------------#
    #   Build the network and load the trained weights
    #---------------------------------------------------#
    def generate(self):
        self.model = unet(self.model_image_size, self.num_classes)
        self.model.load_weights(self.model_path)
        print('{} model loaded.'.format(self.model_path))

        if self.num_classes <= 21:
            # fixed VOC-style palette for up to 21 classes
            self.colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128),
                           (128, 0, 128), (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0),
                           (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128),
                           (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0),
                           (0, 64, 128), (128, 64, 12)]
        else:
            # otherwise spread distinct hues over the classes
            # (note: self.class_names must be set before this branch is reached)
            hsv_tuples = [(x / len(self.class_names), 1., 1.)
                          for x in range(len(self.class_names))]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                                   self.colors))

    def letterbox_image(self, image, size):
        # resize without distortion, padding the borders with gray
        image = image.convert("RGB")
        iw, ih = image.size
        w, h = size
        scale = min(w / iw, h / ih)
        nw = int(iw * scale)
        nh = int(ih * scale)

        image = image.resize((nw, nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128, 128, 128))
        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
        return new_image, nw, nh

    #---------------------------------------------------#
    #   Segment a single image
    #---------------------------------------------------#
    def detect_image(self, image):
        # convert to RGB so grayscale inputs do not fail at predict time
        image = image.convert('RGB')

        # keep a copy of the input for blending later
        old_img = copy.deepcopy(image)
        orininal_h = np.array(image).shape[0]
        orininal_w = np.array(image).shape[1]

        # distortion-free resize with gray padding, then normalize
        img, nw, nh = self.letterbox_image(image, (self.model_image_size[1], self.model_image_size[0]))
        img = np.asarray([np.array(img) / 255])

        pr = self.model.predict(img)[0]
        # per-pixel argmax to get the predicted class map
        pr = pr.argmax(axis=-1).reshape([self.model_image_size[0], self.model_image_size[1]])
        # crop away the gray letterbox padding
        pr = pr[int((self.model_image_size[0] - nh) // 2):int((self.model_image_size[0] - nh) // 2 + nh),
                int((self.model_image_size[1] - nw) // 2):int((self.model_image_size[1] - nw) // 2 + nw)]

        # build a new image, coloring each pixel according to its class
        seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
        for c in range(self.num_classes):
            seg_img[:, :, 0] += ((pr[:, :] == c) * (self.colors[c][0])).astype('uint8')
            seg_img[:, :, 1] += ((pr[:, :] == c) * (self.colors[c][1])).astype('uint8')
            seg_img[:, :, 2] += ((pr[:, :] == c) * (self.colors[c][2])).astype('uint8')

        image = Image.fromarray(np.uint8(seg_img)).resize((orininal_w, orininal_h), Image.NEAREST)
        blend_image = Image.blend(old_img, image, 0.7)
        return image, blend_image

    def get_FPS(self, image, test_interval):
        orininal_h = np.array(image).shape[0]
        orininal_w = np.array(image).shape[1]

        img, nw, nh = self.letterbox_image(image, (self.model_image_size[1], self.model_image_size[0]))
        img = np.asarray([np.array(img) / 255])

        # one warm-up inference, then time test_interval timed runs
        pr = self.model.predict(img)[0]
        pr = pr.argmax(axis=-1).reshape([self.model_image_size[0], self.model_image_size[1]])
        pr = pr[int((self.model_image_size[0] - nh) // 2):int((self.model_image_size[0] - nh) // 2 + nh),
                int((self.model_image_size[1] - nw) // 2):int((self.model_image_size[1] - nw) // 2 + nw)]
        image = Image.fromarray(np.uint8(pr)).resize((orininal_w, orininal_h), Image.NEAREST)

        t1 = time.time()
        for _ in range(test_interval):
            pr = self.model.predict(img)[0]
            pr = pr.argmax(axis=-1).reshape([self.model_image_size[0], self.model_image_size[1]])
            pr = pr[int((self.model_image_size[0] - nh) // 2):int((self.model_image_size[0] - nh) // 2 + nh),
                    int((self.model_image_size[1] - nw) // 2):int((self.model_image_size[1] - nw) // 2 + nw)]
            image = Image.fromarray(np.uint8(pr)).resize((orininal_w, orininal_h), Image.NEAREST)
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time
```
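Example usage of the prediction wrapper (the weight path and test image below are placeholders):

```python
from PIL import Image

predictor = Unet(model="logs/unet_weights.h5",   # placeholder weight file
                 model_image_size=(512, 512, 3),
                 num_classes=2)

img = Image.open("test.jpg")                     # placeholder test image
mask, blend = predictor.detect_image(img)
blend.save("test_blend.png")                     # prediction overlaid on the input
```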