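I. Installing the Object Detection API
Reference: 【TensorFlow学习二】Object Detection API的安装与使用 (TensorFlow Study Notes 2: Installing and Using the Object Detection API)
II. Preparation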
1. Create a directory named ssd_model under object_detection/:
mkdir object_detection/ssd_model
Extract the downloaded VOCdevkit dataset into it, so that the dataset path is ./object_detection/ssd_model/VOCdevkit/
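A quick layout check at this point can save a failed conversion run later. The sketch below assumes the standard PASCAL VOC layout (VOC2012/Annotations, VOC2012/JPEGImages, VOC2012/ImageSets/Main) under the dataset path. The helper name missing_voc_dirs is hypothetical and is not part of the Object Detection API.

```python
import os

# Subdirectories of the standard PASCAL VOC layout (an assumption of this sketch).
EXPECTED_SUBDIRS = ("Annotations", "JPEGImages", os.path.join("ImageSets", "Main"))


def missing_voc_dirs(devkit_root, year="VOC2012"):
    """Return the expected dataset subdirectories that do not exist."""
    year_dir = os.path.join(devkit_root, year)
    return [
        os.path.join(year_dir, sub)
        for sub in EXPECTED_SUBDIRS
        if not os.path.isdir(os.path.join(year_dir, sub))
    ]


if __name__ == "__main__":
    # Run from models/research, like the other commands in this post.
    missing = missing_voc_dirs("object_detection/ssd_model/VOCdevkit")
    print("layout looks complete" if not missing else "missing: %s" % missing)
```

An empty result means the three directories are in place; anything else usually means the archive was extracted one level too deep or too shallow.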
2. Run the conversion script to generate TFRecord files for the training and validation sets:
python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/data/pascal_label_map.pbtxt --data_dir=object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=train --output_path=object_detection/ssd_model/pascal_train.record
python ./object_detection/dataset_tools/create_pascal_tf_record.py --label_map_path=object_detection/data/pascal_label_map.pbtxt --data_dir=object_detection/ssd_model/VOCdevkit/ --year=VOC2012 --set=val --output_path=object_detection/ssd_model/pascal_val.record
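3. This generates two files in the ssd_model/ directory, pascal_train.record and pascal_val.record, each about 670 MB. Next, copy the files used for training on the pets dataset; we will adapt that configuration to train on our own data: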
cp object_detection/data/pascal_label_map.pbtxt object_detection/ssd_model/
cp object_detection/samples/configs/ssd_mobilenet_v1_pets.config object_detection/ssd_model/
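Before going further, the two .record files can be sanity-checked without loading TensorFlow. The sketch below walks the TFRecord framing (an 8-byte little-endian length, a 4-byte CRC, the payload, and another 4-byte CRC) and counts the records. It assumes the files are uncompressed and it does not verify the CRCs; the function name count_tfrecords is made up for this example.

```python
import os
import struct


def count_tfrecords(path):
    """Count records in an uncompressed TFRecord file without verifying CRCs."""
    size = os.path.getsize(path)
    count = 0
    offset = 0
    with open(path, "rb") as f:
        while offset < size:
            f.seek(offset)
            header = f.read(8)  # little-endian uint64 payload length
            if len(header) < 8:
                raise ValueError("truncated length header in %s" % path)
            (length,) = struct.unpack("<Q", header)
            # Frame = 8-byte length + 4-byte CRC + payload + 4-byte CRC.
            offset += 8 + 4 + length + 4
            if offset > size:
                raise ValueError("truncated record in %s" % path)
            count += 1
    return count
```

Calling count_tfrecords("object_detection/ssd_model/pascal_train.record") should return the number of training images that were converted; a ValueError suggests the conversion was interrupted.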
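4. Open pascal_label_map.pbtxt and take a look. The file uses a JSON-like format and lists the labels that appear in the dataset. The Pascal dataset has 20 labels in total.
Then open the config file ssd_mobilenet_v1_pets.config and change num_classes to 20.
The default number of training steps is num_steps: 200000. Change it to suit your needs; note that training is slow, on the order of days, so it is reasonable to lower this value.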
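The two edits in this step can also be scripted. The sketch below counts the item entries in the label map text and rewrites num_classes and num_steps in the config text with regular expressions. It assumes each field appears exactly once in the config and raises otherwise; the function names are hypothetical, and 50000 is only an example step count.

```python
import re


def count_label_items(label_map_text):
    """Count the `item { ... }` entries in a label map's text."""
    return len(re.findall(r"\bitem\s*\{", label_map_text))


def patch_config(config_text, num_classes, num_steps):
    """Return the config text with num_classes and num_steps replaced."""
    config_text, n_classes = re.subn(
        r"(\bnum_classes\s*:\s*)\d+", r"\g<1>%d" % num_classes, config_text)
    config_text, n_steps = re.subn(
        r"(\bnum_steps\s*:\s*)\d+", r"\g<1>%d" % num_steps, config_text)
    if n_classes != 1 or n_steps != 1:
        raise ValueError("expected exactly one num_classes and one num_steps field")
    return config_text


if __name__ == "__main__":
    # Tiny stand-ins for the real files, just to show the effect.
    sample_map = "item {\n  id: 1\n  name: 'aeroplane'\n}\nitem {\n  id: 2\n  name: 'bicycle'\n}\n"
    sample_cfg = "model {\n  ssd {\n    num_classes: 37\n  }\n}\ntrain_config: {\n  num_steps: 200000\n}\n"
    print(patch_config(sample_cfg, count_label_items(sample_map), 50000))
```

Deriving num_classes from the label map rather than typing it by hand keeps the two files from drifting apart.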
5. Next, update the file paths:
train_input_reader: {
  tf_record_input_reader {
    input_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_train.record"
  }
  label_map_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
}
eval_input_reader: {
  tf_record_input_reader {
    input_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_val.record"
  }
  label_map_path: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}
Extract the downloaded ssd_mobilenet_v1_coco_2018_01_28 archive into object_detection/ssd_model/ssd_mobilenet, then enter the checkpoint path in the config file ssd_mobilenet_v1_pets.config:
fine_tune_checkpoint: "/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/ssd_mobilenet/model.ckpt"
With that done, the model is ready to train.
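Since a training run takes days, it is worth confirming that every path entered in the config actually exists before starting. The sketch below pulls the input_path, label_map_path, and fine_tune_checkpoint values out of the config text and checks them on disk. It assumes the checkpoint is stored with a model.ckpt.index file next to its data files, so the checkpoint prefix itself is not expected to be a file. Both helper names are hypothetical.

```python
import os
import re

# Config fields whose values are filesystem paths.
PATH_FIELDS = ("input_path", "label_map_path", "fine_tune_checkpoint")


def config_paths(config_text):
    """Return (field, path) pairs for the path-valued fields in the config text."""
    pattern = r'\b(%s)\s*:\s*"([^"]*)"' % "|".join(PATH_FIELDS)
    return re.findall(pattern, config_text)


def path_is_usable(field, path):
    """Check a path; a checkpoint is a file prefix, so look for its .index file."""
    if field == "fine_tune_checkpoint":
        return os.path.isfile(path + ".index")
    return os.path.isfile(path)
```

Reading ssd_mobilenet_v1_pets.config into a string and printing every pair for which path_is_usable returns False gives a short list of paths to fix.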
III. Training the Model
cd models/research
python object_detection/legacy/train.py --train_dir object_detection/ssd_model/train --pipeline_config_path object_detection/ssd_model/ssd_mobilenet_v1_pets.config
After a long wait, the trained model checkpoints appear in the object_detection/ssd_model/train directory.
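The export step needs the prefix of one specific checkpoint. The sketch below scans the training directory for files named model.ckpt-<step>.index and returns the prefix with the highest step. The naming pattern is an assumption based on the model.ckpt-50000 prefix used in this post, and the function name is made up for illustration; TensorFlow's own tf.train.latest_checkpoint serves the same purpose when TensorFlow is loaded.

```python
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-N prefix with the largest N, or None if there is none."""
    steps = []
    for name in os.listdir(train_dir):
        match = re.fullmatch(r"model\.ckpt-(\d+)\.index", name)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(train_dir, "model.ckpt-%d" % max(steps))
```

The returned value can be passed directly as --trained_checkpoint_prefix.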
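Then create the folder ssd_model/model and export the inference graph. The checkpoint prefix must point into the training directory used above (50000 is the step number of the checkpoint; use the one your run produced):
python object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path object_detection/ssd_model/ssd_mobilenet_v1_pets.config --trained_checkpoint_prefix object_detection/ssd_model/train/model.ckpt-50000 --output_directory object_detection/ssd_model/model/
This generates the .pb file. Then save the contents of pascal_label_map.pbtxt as a .txt file to serve as the label file, and the model is ready to use.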
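The layout of the .txt label file is not spelled out here, so the following is only one plausible reading: one label name per line, ordered by id. The sketch parses the item entries of the label map text with regular expressions; label_names is a hypothetical helper, not part of the Object Detection API.

```python
import re


def label_names(label_map_text):
    """Return label names ordered by id, parsed from label map text."""
    items = re.findall(r"item\s*\{(.*?)\}", label_map_text, flags=re.S)
    pairs = []
    for body in items:
        item_id = re.search(r"\bid\s*:\s*(\d+)", body)
        # \bname does not match display_name, because "_" is a word character.
        name = re.search(r"""\bname\s*:\s*['"]([^'"]+)['"]""", body)
        if item_id and name:
            pairs.append((int(item_id.group(1)), name.group(1)))
    return [name for _, name in sorted(pairs)]
```

Writing "\n".join(label_names(text)) to a file then yields the label list under that assumption.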
IV. Testing the Model
import cv2
import numpy as np
import tensorflow as tf  # uses the TensorFlow 1.x API (tf.GraphDef, tf.gfile, tf.Session)
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        self.PATH_TO_CKPT = '/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/model/frozen_inference_graph.pb'
        self.PATH_TO_LABELS = '/data2/zzw/Tensorflow-test/models/research/object_detection/ssd_model/pascal_label_map.pbtxt'
        # Must match num_classes in the training config (20 for Pascal VOC).
        self.NUM_CLASSES = 20
        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        # Load the frozen inference graph exported after training.
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        # Build the class id -> category lookup used when drawing labels.
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map,
                                                                    max_num_classes=self.NUM_CLASSES,
                                                                    use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                # cv2.imread returns BGR, while the model expects RGB input.
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_rgb, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection (drawn on the original image).
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)
        cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
        cv2.imshow("detection", image)
        cv2.waitKey(0)


if __name__ == '__main__':
    image = cv2.imread('/home/wow/Github/models/research/object_detection/ssd_model/img/test.jpg')
    if image is None:
        raise IOError('could not read the test image')
    detector = TOD()
    detector.detect(image)
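The script above only draws the detections. When the results are needed as data, the normalized boxes have to be scaled to pixel coordinates first. The sketch below assumes each box is [ymin, xmin, ymax, xmax] in the range 0 to 1, which is what use_normalized_coordinates=True implies, and keeps only detections at or above a score threshold. The function name and the 0.5 default threshold are choices made for this example.

```python
def boxes_to_pixels(boxes, scores, classes, image_height, image_width, min_score=0.5):
    """Keep detections scoring at least min_score and scale their boxes to pixels.

    Each box is assumed to be [ymin, xmin, ymax, xmax], normalized to [0, 1].
    Returned boxes are (left, top, right, bottom) in pixels.
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        results.append({
            "class_id": int(cls),
            "score": float(score),
            "box": (int(round(xmin * image_width)), int(round(ymin * image_height)),
                    int(round(xmax * image_width)), int(round(ymax * image_height))),
        })
    return results
```

Inside detect, it could be called with np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), and the height and width taken from image.shape[:2].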
References:
[Tensorflow] 使用SSD-MobileNet训练模型 (Training a Model with SSD-MobileNet)
Tensorflow物体检测(Object Detection)API的使用 (Using the TensorFlow Object Detection API)
Tensorflow detection model zoo