对于语义分割来说,网络输出的图像为HxW的二维矩阵,其上面每个像素点的值是这个像素点的类别(如,像素点值为1 ,表示这个像素点属于第一类)。然而,对于一个二维矩阵,生成的图像是一个灰度图,并且灰度值很低,非常不利于人观察(如下图为voc2007的标签,图中白色为人为标记的,真正的网络预测并没有这一部分)
那么我们就需要对其进行染色处理,使其变成利于人观察的图像(如下图这样)
对图像染色有很多方法,下面介绍一种最简单的一种方法:
def cam_mask(mask,palette,n):
seg_img = np.zeros((np.shape(mask)[0], np.shape(mask)[1], 3))
for c in range(n):
seg_img[:, :, 0] += ((mask[:, :] == c) * (palette[c][0])).astype('uint8')
seg_img[:, :, 1] += ((mask[:, :] == c) * (palette[c][1])).astype('uint8')
seg_img[:, :, 2] += ((mask[:, :] == c) * (palette[c][2])).astype('uint8')
colorized_mask = Image.fromarray(np.uint8(seg_img))
return colorized_mask
利用这个函数就可以将网络预测结果生成彩色图像,其中mask为预测结果,palette为染色版,即你需要用什么颜色进行染色,是一个列表加元组的形式,n为网络预测的类别。
下面拿voc数据集举例,它的染色板为:
palette = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
(128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
(64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
(128, 64, 12)]
下面简单解释一下这个代码的含义:
对于输入HxW二维预测结果,我们先生成一个HxWx3的全零矩阵seg_img。
然后从0到类别数(21)开始循环,如果预测结果中有与类别数c相同的值,那么这个位置的值为1,否在为0。这样会生成一个掩码,这个掩码的对应为1的位置就是预测结果中属于第c个类别的位置。然后,我们将染色板的三个值分别加到之前的seg_img的三个通道上,这样就形成了HxWxC的RGB三个通道的彩色图像。