Model Training and Testing
I. Starting from the DeepLabv3+ model, two files mainly need to be modified:
data_generator.py
train_utils.py
(1) Add the dataset description
In datasets/data_generator.py, add a description of your own dataset:

_CAMVID_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1035,
        'val': 31,
    },
    num_classes=3,
    ignore_label=255,
)

This dataset has 3 classes in total, background included. Because ignore_label is not used, it is not counted in num_classes.
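splits_to_sizes is meant to record how many examples each split contains, so it is safer to count them than to type them by hand. A minimal sketch, assuming (hypothetically) that each split's image names are listed one per line in train.txt and val.txt, the list-file layout used by DeepLab's build_voc2012_data.py conversion script. The helpers count_split_sizes and make_descriptor are hypothetical, and the namedtuple only mirrors the three DatasetDescriptor fields used above.

```python
import collections
import os

# Mirrors the fields of DatasetDescriptor in deeplab/datasets/data_generator.py.
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor', ['splits_to_sizes', 'num_classes', 'ignore_label'])


def count_split_sizes(list_folder, splits=('train', 'val')):
    """Returns {split: number of non-empty lines in <list_folder>/<split>.txt}.

    Hypothetical helper: assumes one image name per line in each list file.
    """
    sizes = {}
    for split in splits:
        list_path = os.path.join(list_folder, split + '.txt')
        with open(list_path) as f:
            sizes[split] = sum(1 for line in f if line.strip())
    return sizes


def make_descriptor(list_folder, num_classes, ignore_label=255):
    """Builds a descriptor whose split sizes are counted from the list files."""
    return DatasetDescriptor(
        splits_to_sizes=count_split_sizes(list_folder),
        num_classes=num_classes,
        ignore_label=ignore_label)
```

For example, make_descriptor('datasets/blackboard/ImageSets', num_classes=3) (the folder path is hypothetical) prints a descriptor whose sizes can be copied into data_generator.py.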
(2) Register the dataset
In the same file, add the new descriptor to _DATASETS_INFORMATION:

_DATASETS_INFORMATION = {
    'cityscapes': _CITYSCAPES_INFORMATION,
    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
    'ade20k': _ADE20K_INFORMATION,
    'camvid': _CAMVID_INFORMATION,
    # 'mydata': _MYDATA_INFORMATION,
}
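The dictionary key ('camvid') is the name later passed to train.py as --dataset. The following is a simplified sketch, not the actual DeepLab source, of how a name-to-descriptor lookup of this kind behaves; lookup_dataset is a hypothetical name, and only the new entry is shown in the dictionary.

```python
import collections

DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor', ['splits_to_sizes', 'num_classes', 'ignore_label'])

_CAMVID_INFORMATION = DatasetDescriptor(
    splits_to_sizes={'train': 1035, 'val': 31},
    num_classes=3,
    ignore_label=255)

# Only the new entry is shown; the real dictionary also holds cityscapes etc.
_DATASETS_INFORMATION = {'camvid': _CAMVID_INFORMATION}


def lookup_dataset(dataset_name, split_name):
    """Simplified sketch: reject unknown dataset names and split names."""
    if dataset_name not in _DATASETS_INFORMATION:
        raise ValueError('The specified dataset is not supported yet.')
    info = _DATASETS_INFORMATION[dataset_name]
    if split_name not in info.splits_to_sizes:
        raise ValueError('data split name %s not recognized' % split_name)
    return info
```

A name that was never registered (such as the commented-out 'mydata') fails immediately, which is why the --dataset value must match a key exactly.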
(3) Modify train_utils.py
In utils/train_utils.py, change the exclude_list setting at line 210 as follows, so that the logits layer is not loaded from the pre-trained weights:

exclude_list = ['global_step', 'logits']
if not initialize_last_layer:
    exclude_list.extend(last_layers)
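With this change, the train.py flags initialize_last_layer and last_layers_contain_logits_only decide which scopes are skipped when the checkpoint is restored. A pure-Python sketch of that logic, under stated assumptions: the scope names below are assumed to match those in DeepLab's model.py and are illustrative rather than authoritative, the variable names in the usage are hypothetical, and excluded scopes are treated as prefixes (matched with re.match), which is how TensorFlow's scope filtering behaves.

```python
import re

# Assumed scope names (illustrative): the logits scope plus the extra
# "last layer" scopes that are skipped when logits_only is False.
LOGITS_SCOPE = 'logits'
EXTRA_LAST_LAYER_SCOPES = ['image_pooling', 'aspp', 'concat_projection',
                           'decoder', 'meta_architecture']


def build_exclude_list(initialize_last_layer, last_layers_contain_logits_only):
    """Reproduces the modified exclude_list logic from train_utils.py."""
    exclude_list = ['global_step', 'logits']
    if not initialize_last_layer:
        last_layers = [LOGITS_SCOPE]
        if not last_layers_contain_logits_only:
            last_layers += EXTRA_LAST_LAYER_SCOPES
        exclude_list.extend(last_layers)
    return exclude_list


def variables_to_restore(variable_names, exclude_list):
    """Keeps the names that do not start with any excluded scope."""
    return [name for name in variable_names
            if not any(re.match(scope, name) for scope in exclude_list)]
```

With initialize_last_layer=False and last_layers_contain_logits_only=True, the backbone, ASPP and decoder are restored and only logits (and global_step) start fresh; with last_layers_contain_logits_only=False, only the backbone is restored.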
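To fine-tune DeepLab on another dataset, modify the input parameters in deeplab/train.py.
Some options:
To use all of the pre-trained weights, set initialize_last_layer=True.
To use only the network backbone, set initialize_last_layer=False and last_layers_contain_logits_only=False.
To use all pre-trained weights except the logits, set initialize_last_layer=False and last_layers_contain_logits_only=True. This fits a custom dataset, whose number of classes differs from the pre-trained model's (the logits were already excluded from loading above).
The settings used here are:
initialize_last_layer=False  # line 157
last_layers_contain_logits_only=True  # line 160

II. Network Training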
(1) Download the pre-trained model
Download link: https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
Download it into the deeplab directory, then extract it:
tar -zxvf deeplabv3_cityscapes_train_2018_02_06.tar.gz
Note that the extracted directory is /lwh/models/research/deeplab/deeplabv3_cityscapes_train
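The --tf_initial_checkpoint flag used later takes a checkpoint prefix such as .../model.ckpt, while on disk a TensorFlow checkpoint typically appears as model.ckpt.index plus one or more model.ckpt.data-* shards. A small sketch, doing the same job as the tar command above and then listing the prefixes it finds; extract_and_find_checkpoints is a hypothetical helper, and the exact tarball contents are an assumption.

```python
import os
import tarfile


def extract_and_find_checkpoints(tarball_path, dest_dir):
    """Extracts a .tar.gz archive and returns the checkpoint prefixes inside.

    A prefix is derived from every *.index file, e.g.
    dest_dir/deeplabv3_cityscapes_train/model.ckpt.index -> .../model.ckpt
    """
    with tarfile.open(tarball_path, 'r:gz') as tar:
        tar.extractall(dest_dir)
    prefixes = []
    for root, _, files in os.walk(dest_dir):
        for name in files:
            if name.endswith('.index'):
                prefixes.append(os.path.join(root, name[:-len('.index')]))
    return sorted(prefixes)
```

The returned path is the value to pass as --tf_initial_checkpoint.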
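(2) Correct class imbalance
The dataset in the blackboard segmentation case study is a 3-class problem in which background takes up a very large share of the pixels, so the class weights are set to 1, 3, 3. Note: the weights affect the final segmentation performance, and suitable values vary from dataset to dataset. Modify the weights at line 145 of common.py as follows:

flags.DEFINE_multi_float(
    'label_weights', [1.0, 3.0, 3.0],
    'A list of label weights, each element represents the weight for the label '
    'of its index, for example, label_weights = [0.1, 0.5] means the weight '
    'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
    'the labels have the same weight 1.0.')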
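Conceptually, label_weights act as per-pixel loss weights: each pixel's cross-entropy term is scaled by the weight of its ground-truth class, and pixels equal to ignore_label contribute nothing. The NumPy sketch below illustrates that mapping (it is assumed to mirror, not copied from, DeepLab's loss code). It also includes a hypothetical inverse-frequency heuristic, offered only as one possible starting point for choosing weights; both function names are hypothetical.

```python
import numpy as np


def pixel_weights(labels, label_weights, ignore_label=255):
    """Maps each pixel's class id to its label weight; ignored pixels get 0."""
    labels = np.asarray(labels)
    weights = np.zeros(labels.shape, dtype=np.float32)
    valid = labels != ignore_label
    weights[valid] = np.asarray(label_weights, dtype=np.float32)[labels[valid]]
    return weights


def inverse_frequency_weights(pixel_counts):
    """Hypothetical heuristic: weight_c = pixels of class 0 / pixels of class c."""
    counts = np.asarray(pixel_counts, dtype=np.float64)
    return (counts[0] / counts).tolist()
```

For instance, if background had three times as many pixels as each foreground class, inverse_frequency_weights([900, 300, 300]) would give [1.0, 3.0, 3.0]; in practice the values still need tuning on the validation split.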
(3) Training
Pay attention to the following parameters:
train_logdir: where the files produced during training are stored
dataset_dir: directory containing the dataset's TFRecord files
dataset: set to the dataset name registered in data_generator.py
The training command for the custom dataset, run from the ~/models/research/deeplab directory, is:

python train.py \
  --training_number_of_steps=30000 \
  --train_split="train" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size=801,801 \
  --train_batch_size=2 \
  --dataset="camvid" \
  --tf_initial_checkpoint='/lwh/models/research/deeplab/deeplabv3_cityscapes_train/model.ckpt' \
  --train_logdir='/lwh/models/research/deeplab/exp/blackboard_train/train' \
  --dataset_dir='/lwh/models/research/deeplab/datasets/blackboard/tfrecord'
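Each train_crop_size value should have the form output_stride * k + 1 for an integer k. Two small hypothetical helpers check a size against that rule and round a size up to the nearest size that satisfies it.

```python
def is_valid_crop_size(size, output_stride=16):
    """True if size == output_stride * k + 1 for some integer k >= 1."""
    return size > 1 and (size - 1) % output_stride == 0


def next_valid_crop_size(size, output_stride=16):
    """Smallest value of the form output_stride * k + 1 that is >= size."""
    k = -(-(size - 1) // output_stride)  # ceiling division
    return output_stride * k + 1
```

With output_stride=16, 801 = 16 * 50 + 1 passes the check, while 800 would be rounded up to 801.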
Rule for choosing train_crop_size: use output_stride * k + 1, where k is an integer; for example 321x321, 513x513, or 801x801.
(4) Model export
python export_model.py \
  --logtostderr \
  --checkpoint_path="/lwh/models/research/deeplab/exp/blackboard_train/train/model.ckpt-30000" \
  --export_path="/lwh/models/research/deeplab/exp/blackboard_train/train/frozen_inference_graph.pb" \
  --model_variant="xception_65" \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --num_classes=3 \
  --crop_size=1080 \
  --crop_size=1920 \
  --inference_scales=1.0
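Points to note:
--checkpoint_path: path of your own saved model checkpoint
--export_path: path where the exported model is saved
--num_classes=3: number of classes in your data, background included
--crop_size=1080: the first value is the input height h expected by the model
--crop_size=1920: the second value is the input width w expected by the model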
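The architecture flags (model_variant, atrous_rates, output_stride, decoder_output_stride) must match the ones used for training, otherwise the checkpoint cannot be restored into the exported graph. One way to keep them in sync is to build the export command from a single list of values. The sketch below is hypothetical (ARCH_FLAGS and build_export_args are illustrative names) and simply reproduces the command above as an argv list.

```python
# Architecture flags copied from the training command; reusing one list keeps
# training and export consistent.
ARCH_FLAGS = [
    ('model_variant', 'xception_65'),
    ('atrous_rates', 6), ('atrous_rates', 12), ('atrous_rates', 18),
    ('output_stride', 16),
    ('decoder_output_stride', 4),
]


def build_export_args(checkpoint_path, export_path, num_classes, height, width):
    """Returns an argv list for export_model.py (hypothetical helper)."""
    args = ['python', 'export_model.py', '--logtostderr',
            '--checkpoint_path=%s' % checkpoint_path,
            '--export_path=%s' % export_path]
    args += ['--%s=%s' % (name, value) for name, value in ARCH_FLAGS]
    args += ['--num_classes=%d' % num_classes,
             '--crop_size=%d' % height,  # first crop_size: height h
             '--crop_size=%d' % width,   # second crop_size: width w
             '--inference_scales=1.0']
    return args
```

The list can be passed to subprocess.run(..., check=True) from the deeplab directory; because arguments are not parsed by a shell, the paths need no extra quoting.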
III. Model Testing
Here is the code:
# -*- coding: utf-8 -*-
# DeepLab demo: load a frozen graph and visualize its predictions
import os

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf


class DeepLabModel(object):
    """Loads a frozen DeepLab model and runs inference."""

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 1920
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, frozen_graph_path):
        """Creates and loads the exported frozen inference graph (.pb)."""
        self.graph = tf.Graph()
        # tf.compat.v1 keeps this TF1-style code working on TF 1.x and 2.x.
        with open(frozen_graph_path, 'rb') as f:
            graph_def = tf.compat.v1.GraphDef.FromString(f.read())
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.compat.v1.Session(graph=self.graph)

    def run(self, image):
        """Runs inference on a single image.

        Args:
            image: A PIL.Image object, raw input image.

        Returns:
            resized_image: RGB image resized from the original input image.
            seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # The original code forces the export crop size (w, h) = (1920, 1080);
        # remove this line to resize by INPUT_SIZE and keep the aspect ratio.
        target_size = (1920, 1080)
        # Image.LANCZOS replaces Image.ANTIALIAS, which Pillow 10 removed.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """Creates a label colormap used in the PASCAL VOC segmentation benchmark.

    Returns:
        A colormap for visualizing segmentation results.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3
    return colormap


def label_to_color_image(label):
    """Adds color defined by the dataset colormap to the label.

    Args:
        label: A 2D array with integer type, storing the segmentation label.

    Returns:
        result: A 2D array with floating type. The element of the array is the
            color indexed by the corresponding element in the input label to
            the PASCAL color map.

    Raises:
        ValueError: If label is not of rank 2 or its value is larger than the
            color map maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')
    colormap = create_pascal_label_colormap()
    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')
    return colormap[label]


def vis_segmentation(image, seg_map):
    """Visualizes the input image, segmentation map and overlay view."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    # Legend: one color swatch per class present in the prediction.
    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8),
               interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)
    plt.show()


LABEL_NAMES = np.asarray(['background', 'blackboard', 'screen'])
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

# Path to the exported frozen graph
download_path = r"D:\python_project\deeplabv3+\blackboard_v2.pb"
MODEL = DeepLabModel(download_path)
print('model loaded successfully!')


def run_visualization(imagefile):
    """Runs DeepLab semantic segmentation on one image and visualizes the result."""
    original_im = Image.open(imagefile)
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(original_im)
    print(seg_map.shape)
    vis_segmentation(resized_im, seg_map)


images_dir = r'D:\python_project\deeplabv3+\test_img'  # directory of test images
images = sorted(os.listdir(images_dir))
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))
print('Done.')
Two things to note:
1. Change images_dir to the directory where your own images are stored.
2. Change INPUT_SIZE = 1920 to the larger of your images' height and width.
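Besides the visual check, it can help to print how much of each test image is assigned to each class, for example to spot images where the blackboard or screen class is missing entirely. A small hypothetical helper that works on the seg_map returned by MODEL.run and the LABEL_NAMES array:

```python
import numpy as np


def class_pixel_ratios(seg_map, label_names):
    """Returns {label name: fraction of pixels predicted as that class}."""
    seg_map = np.asarray(seg_map)
    counts = np.bincount(seg_map.ravel(), minlength=len(label_names))
    total = seg_map.size
    return {name: float(counts[i]) / total
            for i, name in enumerate(label_names)}
```

Inside run_visualization, calling print(class_pixel_ratios(seg_map, LABEL_NAMES)) right after MODEL.run would log these fractions for every test image.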
Test results