[TensorFlow Study Notes 3] Training and Testing MobileNet-SSD with TensorFlow

I. Installing the Object Detection API

Reference link: [TensorFlow Study Notes 2] Installing and Using the Object Detection API

II. Preparation

1. Create the directory ssd_model under object_detection/

mkdir object_detection/ssd_model

Unpack the downloaded VOCdevkit dataset into it, so that the dataset path is

./object_detection/ssd_model/VOCdevkit/
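
Before generating the TFRecord files, it can help to confirm that the archive was unpacked where the later commands expect it. The sketch below assumes the standard Pascal VOC 2012 layout (VOC2012/Annotations, VOC2012/JPEGImages, VOC2012/ImageSets/Main); the helper name missing_voc_dirs is made up for illustration and is not part of the Object Detection API.

```python
import os

# Subdirectories of the standard Pascal VOC 2012 layout (an assumption about
# how the archive unpacks; adjust if your copy differs).
EXPECTED_SUBDIRS = (
    os.path.join("VOC2012", "Annotations"),
    os.path.join("VOC2012", "JPEGImages"),
    os.path.join("VOC2012", "ImageSets", "Main"),
)


def missing_voc_dirs(devkit_root):
    """Return the expected subdirectories that are absent under devkit_root."""
    return [sub for sub in EXPECTED_SUBDIRS
            if not os.path.isdir(os.path.join(devkit_root, sub))]


if __name__ == "__main__":
    # Relative path as used in this post; run from models/research.
    missing = missing_voc_dirs("object_detection/ssd_model/VOCdevkit")
    print("layout looks complete" if not missing else "missing: %s" % missing)
```

An empty result means the three directories are in place; anything listed usually points to an extra or missing directory level created while unpacking.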

2. Generate the TFRecord files (the relative paths assume the working directory is models/research)

python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/data/pascal_label_map.pbtxt --data_dir=object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=train --output_path=object_detection/ssd_model/pascal_train.record

python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/data/pascal_label_map.pbtxt --data_dir=object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=val --output_path=object_detection/ssd_model/pascal_val.record
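
The two commands above differ only in the --set value and the output file name, so they can be generated from one template. The following is a minimal sketch of that idea; the function name tf_record_command is hypothetical, and the flags are copied from the commands above rather than checked against any particular version of the script.

```python
import sys

SCRIPT = "object_detection/dataset_tools/create_pascal_tf_record.py"
MODEL_DIR = "object_detection/ssd_model"


def tf_record_command(split):
    """Build the argument list for one split ('train' or 'val')."""
    return [
        sys.executable, SCRIPT,
        "--label_map_path=object_detection/data/pascal_label_map.pbtxt",
        "--data_dir=%s/VOCdevkit/" % MODEL_DIR,
        "--year=VOC2012",
        "--set=%s" % split,
        "--output_path=%s/pascal_%s.record" % (MODEL_DIR, split),
    ]


if __name__ == "__main__":
    # Only prints the commands; pass each list to subprocess.check_call
    # from models/research to actually run the conversion.
    for split in ("train", "val"):
        print(" ".join(tf_record_command(split)))
```

Keeping the split name in one place avoids the easy mistake of writing the val set to the train file name.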

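3. This produces two files in ssd_model/, pascal_train.record and pascal_val.record, each about 670 MB. Next, copy the files used for training on the pet dataset; we will modify this configuration to train on our own data.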

cp object_detection/data/pascal_label_map.pbtxt object_detection/ssd_model/
cp object_detection/samples/configs/ssd_mobilenet_v1_pets.config object_detection/ssd_model/

4. Open pascal_label_map.pbtxt and take a look. The file holds the label set in a JSON-like format and lists every label in the dataset. The Pascal dataset has 20 labels in total.
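
To double-check the label count instead of counting by eye, the file can be scanned with a small regex. This is only a sketch: it assumes each entry looks like item { id: N name: '...' }, and it is not the protobuf text-format parser that the Object Detection API itself uses. The two-entry SAMPLE string and the name read_label_map are illustrative.

```python
import re

# Illustrative two-entry excerpt in the label-map format.
SAMPLE = """
item {
  id: 1
  name: 'aeroplane'
}
item {
  id: 2
  name: 'bicycle'
}
"""


def read_label_map(text):
    """Return {id: name} for every item block found in label-map text."""
    labels = {}
    for block in re.findall(r"item\s*\{(.*?)\}", text, re.S):
        item_id = re.search(r"\bid:\s*(\d+)", block)
        name = re.search(r"\bname:\s*['\"]([^'\"]+)['\"]", block)
        if item_id and name:
            labels[int(item_id.group(1))] = name.group(1)
    return labels


if __name__ == "__main__":
    print(len(read_label_map(SAMPLE)))  # number of labels in the sample
```

Run against the full pascal_label_map.pbtxt, the length of the returned dictionary is the value that num_classes in the config has to match.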

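Then open the configuration file ssd_mobilenet_v1_pets.config and change num_classes to 20. The default number of training steps is num_steps: 200000; change it to suit your needs. Note that training is slow, on the order of days, so it is reasonable to lower this value.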
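
Both values can be edited by hand, but a scripted edit is easier to repeat. The sketch below treats the config as plain text and swaps the integer after a given key; set_config_value is a hypothetical helper, the SAMPLE string is a shortened stand-in for the real file, and 50000 is chosen only because it matches the checkpoint number used in the export step of this post.

```python
import re

# Shortened stand-in for the relevant lines of ssd_mobilenet_v1_pets.config.
SAMPLE = "model { ssd { num_classes: 37 } }\ntrain_config: { num_steps: 200000 }\n"


def set_config_value(config_text, key, value):
    """Replace the integer after `key:` everywhere it appears in config_text."""
    pattern = r"(\b%s:\s*)\d+" % re.escape(key)
    new_text, count = re.subn(
        pattern, lambda m: m.group(1) + str(value), config_text)
    if count == 0:
        raise KeyError("%s not found in config" % key)
    return new_text


if __name__ == "__main__":
    text = set_config_value(SAMPLE, "num_classes", 20)
    text = set_config_value(text, "num_steps", 50000)
    print(text)
```

Raising on a missing key is deliberate: a silent no-op would leave the pet dataset's class count in place and only show up later as wrong labels.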

5. Next, update the file paths:

train_input_reader: {
  tf_record_input_reader {
    input_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_train.record"
  }
  label_map_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_val.record"
  }
  label_map_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}

Unpack the downloaded ssd_mobilenet_v1_coco_2018_01_28 into object_detection/ssd_model/ssd_mobilenet

Then put its path into the configuration file ssd_mobilenet_v1_pets.config

fine_tune_checkpoint: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/ssd_mobilenet/model.ckpt"
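
A mistyped path in the config tends to surface only after training has started, so a quick existence check can save a restart. The sketch below pulls the quoted paths out of the config text and reports the ones that are missing. It assumes the checkpoint is stored in the usual TF1 form, where the prefix model.ckpt is accompanied by a model.ckpt.index file; missing_config_paths is a hypothetical helper name.

```python
import os
import re

PATH_RE = re.compile(
    r'\b(input_path|label_map_path|fine_tune_checkpoint):\s*"([^"]+)"')


def missing_config_paths(config_text):
    """Return (key, path) pairs from config_text that do not exist on disk."""
    missing = []
    for key, path in PATH_RE.findall(config_text):
        # A checkpoint prefix is not itself a file, so look for its
        # .index companion instead (assumed TF1 checkpoint layout).
        target = path + ".index" if key == "fine_tune_checkpoint" else path
        if not os.path.exists(target):
            missing.append((key, path))
    return missing
```

Reading ssd_mobilenet_v1_pets.config into a string and passing it to missing_config_paths should return an empty list once all five paths are correct.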

Once this is done, the model is ready to train.

III. Training the Model

cd models/research

python object_detection/legacy/train.py --train_dir object_detection/ssd_model/train --pipeline_config_path object_detection/ssd_model/ssd_mobilenet_v1_pets.config

After a long wait, the trained model checkpoints appear in object_detection/ssd_model/train.

Then create the folder object_detection/ssd_model/model

python object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path object_detection/ssd_model/ssd_mobilenet_v1_pets.config --trained_checkpoint_prefix object_detection/ssd_model/train/model.ckpt-50000 --output_directory object_detection/ssd_model/model/
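
The number in model.ckpt-50000 is the training step of the checkpoint, so it depends on the num_steps value chosen earlier and on when training stopped. A small sketch for finding the newest checkpoint prefix in the training directory follows; it assumes checkpoints are saved as model.ckpt-N.index files, and latest_checkpoint_prefix is a hypothetical helper name.

```python
import os
import re

CKPT_RE = re.compile(r"^model\.ckpt-(\d+)\.index$")


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-N prefix with the largest N in train_dir, or None."""
    steps = [int(m.group(1))
             for m in map(CKPT_RE.match, os.listdir(train_dir)) if m]
    if not steps:
        return None
    return os.path.join(train_dir, "model.ckpt-%d" % max(steps))
```

The returned string is what --trained_checkpoint_prefix expects. With TensorFlow installed, tf.train.latest_checkpoint(train_dir) should serve the same purpose by reading the checkpoint state file that training writes.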

This generates the .pb file. Then convert the contents of pascal_label_map.pbtxt into a .txt file to serve as the label file, and the model is ready to use.
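
The post does not say what the .txt label file should look like. One plausible layout, assumed here purely for illustration, is one class name per line in id order. The sketch below produces that list from label-map text; it relies on id appearing directly before name in each entry, and label_map_to_lines is a hypothetical name.

```python
import re

# Illustrative excerpt, deliberately out of order to show the sorting.
SAMPLE = "item {\n  id: 2\n  name: 'bicycle'\n}\nitem {\n  id: 1\n  name: 'aeroplane'\n}\n"


def label_map_to_lines(pbtxt_text):
    """Return class names from label-map text, ordered by id."""
    pairs = re.findall(
        r"\bid:\s*(\d+)\s*name:\s*['\"]([^'\"]+)['\"]", pbtxt_text)
    return [name for _, name in sorted((int(i), n) for i, n in pairs)]


if __name__ == "__main__":
    # Join with newlines and write to a file to obtain the .txt label list.
    print("\n".join(label_map_to_lines(SAMPLE)))
```

If the program that consumes the model expects a different format, such as an explicit background entry on the first line, the list needs adjusting accordingly.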

IV. Testing the Model

import cv2
import numpy as np
import tensorflow as tf  # written against the TensorFlow 1.x API (tf.Session, tf.GraphDef, tf.gfile)
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        self.PATH_TO_CKPT = '/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/model/frozen_inference_graph.pb'
        self.PATH_TO_LABELS = '/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_label_map.pbtxt'
        self.NUM_CLASSES = 20  # Pascal VOC has 20 classes; must match num_classes in the config
        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map,
                                                                    max_num_classes=self.NUM_CLASSES,
                                                                    use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
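                # cv2.imread returns BGR, while the model was trained on RGB images,
                # so convert the channel order before feeding. The model expects
                # input of shape [1, None, None, 3], hence the added batch dimension.
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image_np_expanded = np.expand_dims(image_rgb, axis=0)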
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

        cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
        cv2.imshow("detection", image)
        cv2.waitKey(0)

if __name__ == '__main__':
    image = cv2.imread('/home/wow/Github/models/research/object_detection/ssd_model/img/test.jpg')
    detector = TOD()
    detector.detect(image)
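
The script above only draws the results. To use the detections in code, the normalized boxes returned by sess.run have to be scaled to pixels and filtered by score. The sketch below assumes the usual Object Detection API convention of [ymin, xmin, ymax, xmax] rows in the range 0 to 1; boxes_to_pixels is a hypothetical helper, and the 0.5 threshold mirrors the default min_score_thresh of the visualization function as far as I know.

```python
def boxes_to_pixels(boxes, scores, classes, image_height, image_width,
                    min_score=0.5):
    """Keep detections scoring at least min_score and scale boxes to pixels.

    boxes holds [ymin, xmin, ymax, xmax] rows normalized to [0, 1].
    Returns a list of (class_id, score, (left, top, right, bottom)) tuples.
    """
    results = []
    for (ymin, xmin, ymax, xmax), score, class_id in zip(boxes, scores, classes):
        if score < min_score:
            continue
        pixel_box = (int(xmin * image_width), int(ymin * image_height),
                     int(xmax * image_width), int(ymax * image_height))
        results.append((int(class_id), float(score), pixel_box))
    return results


if __name__ == "__main__":
    # Two made-up detections on a 100 x 200 (height x width) image.
    print(boxes_to_pixels([[0.25, 0.5, 0.75, 1.0], [0.0, 0.0, 1.0, 1.0]],
                          [0.9, 0.3], [12.0, 7.0], 100, 200))
```

Inside detect, the inputs would be np.squeeze(boxes), np.squeeze(scores) and np.squeeze(classes), with the height and width taken from image.shape[:2].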

 


References: [Tensorflow] Training a Model with SSD-MobileNet

     Using the Tensorflow Object Detection API

     Tensorflow detection model zoo

 
