利用tensorFlow api 识别手术器械

目录

数据集来源:18岁NIPS Workshop一作,用目标检测评估手术技能点

Tensorflow-model:(选中master,点击tag选择和自己tensorflow适配的版本)
https://github.com/tensorflow/models/

包含手术器械数据集及tensorflow-model所需文件
链接: https://pan.baidu.com/s/1eaUsVQEz0-SK_ADJDhw2Ng
提取码:nj4x

文件目录:

dataset\
	├─ssd_mobilenet_v1_coco_2018_01_28
	├─faster_rcnn_resnet101_coco_2018_01_28
		├── checkpoint
		├── frozen_inference_graph.pb
		├── model.ckpt.data-00000-of-00001
		├── model.ckpt.index
		├── model.ckpt.meta
		├── pipeline.config
	├─output
	├─tf_text_data
	    ├─dataset_test.record
    	├─dataset_train.record
    	└─dataset_val.record
	├─tf_text_graph
	    ├─tf_text_graph_common.py
    	├─tf_text_graph_faster_rcnn.py
    	└─tf_text_graph_ssd.py
	└─VOCdevkit
   	 	└─VOC2007
        	├─Annotations
        	├─ImageSets
        	│  └─Main
        	├─JPEGImages
        	└─classfier.py

然后将其解压到tensorflow-model文件下即可。
 

前言

识别手术器械效果:

利用tensorFlow api 识别手术器械
利用tensorFlow api 识别手术器械

tensorflow-model的配置可自行百度。
 

一、下载coco-trained models:

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md
利用tensorFlow api 识别手术器械

把下载的模型放到dataset文件夹中,例如:
faster_rcnn_resnet101_coco_2018_01_28.tar.gz

然后解压:

tar -vxf faster_rcnn_resnet101_coco_2018_01_28.tar.gz -C .

生成faster_rcnn_resnet101_coco_2018_01_28文件夹目录

faster_rcnn_resnet101_coco_2018_01_28/
├── checkpoint
├── frozen_inference_graph.pb
├── model.ckpt.data-00000-of-00001
├── model.ckpt.index
├── model.ckpt.meta
├── pipeline.config
└── saved_model
    ├── saved_model.pb
    └── variables

将里面的pipeline.config复制到dataset文件夹中,并按照本文件夹的名字重新命名(这样的目的是为了规范,当你拥有很多个模型时就不会乱)
 

二、数据集的处理

1.将数据集放到JPEGImages文件夹中,xml标签文件放到Annotations文件夹中

2.执行:

python classfier.py

生成ImageSets/Main文件夹下的三个文件:test.txt ,train.txt,val.txt以及一个trainval.txt

3.cd到dataset目录,生成tfrecord,包括验证集,测试集和训练集


python ./create_dataset_tf_record.py --data_dir=./VOCdevkit --year=VOC2007 --label_map_path=./dataset_label_map.pbtxt --set=val --output_path=./tf_text_data/dataset_val.record
python ./create_dataset_tf_record.py --data_dir=./VOCdevkit --year=VOC2007 --label_map_path=./dataset_label_map.pbtxt --set=val --output_path=./tf_text_data/dataset_test.record
python ./create_dataset_tf_record.py --data_dir=./VOCdevkit --year=VOC2007 --label_map_path=./dataset_label_map.pbtxt --set=train --output_path=./tf_text_data/dataset_train.record

4.修改dataset_label_map.pbtxt,有几类写几类,里面的标签必须和xml文件中的一致
 
 

三、修改pipeline.config文件

cd到下载模型文件中,可以看到pipeline.config
总共包括四个部分:
model:

主要修改:
	num_classes(分类的类别数)
	    image_resizer {
  fixed_shape_resizer {
    height: 300
    width: 300      (输入网络图像的大小尺寸,一般默认)
  }
}

train_config:
主要修改:

batch_size: 64(按照电脑的配置来定)
	
	############################################################################
	####       如果你想使用faster-rcnn,batch_size只能设为1      ################
	############################################################################
	
initial_learning_rate: 0.0023(初始学习率,可以自己调节)
	decay_steps: 600(每多少个steps变化一次学习率)
	decay_factor: 0.96(每次变化的学习率:current_learning_rate = decay_factor*initial_learning_rate)
	num_steps: 50000(训练的次数)
	
fine_tune_checkpoint: "#####/output/model.ckpt-18508"(迁移学习的模型文件)
	
train_input_reader:
	input_path: "#####/tf_text_data/dataset_train.record"(训练集的位置)
	label_map_path: "#####/dataset_label_map.pbtxt"(标签的定义文件)
	
eval_config:
	metrics_set:"coco_detection_metrics"(测量的方式,主要是用mAP来度量)
	num_examples: 300(验证集的数量)
	max_evals: 1(验证的循环次数)
	use_moving_averages:false(采用滑动平均)
	
eval_input_reader:
	tf_record_input_reader {
    	input_path: "#####/tf_text_data/dataset_val.record"(验证集的位置)
 	 }
	label_map_path: "#####/dataset_label_map.pbtxt"(标签的定义文件)
	

四、训练

对于ssd:

python ../research/object_detection/legacy/train.py --logtostderr --train_dir=./output --pipeline_config_path=./ssd_mobilenet_v1_coco_2018_01_28/pipeline.config

对于faster-rcnn:

python ../research/object_detection/legacy/train.py  --logtostderr --train_dir=./output --pipeline_config_path=./faster_rcnn_resnet101_coco_2018_01_28/pipeline.config

 
 

五、保存节点pb

python ../research/object_detection/export_inference_graph.py 
	input_type image_tensor 
	--pipeline_config_path ./ssd_mobilenet_v1_coco_2018_01_28/pipeline.config 
	--trained_checkpoint_prefix ./dateset/output/model.ckpt-50000 
	--output_directory ./output

六、生成pbtxt文件,opencv调用需要

生成pbtxt的py文件位于opencv源码文件/samples/dnn中
地址:
https://github.com/opencv/opencv/tree/4.5.2/samples/dnn/tf_text_graph_ssd.py


python ./tf_text_graph/tf_text_graph_ssd.py 
	--input=./output/frozen_inference_graph.pb 
	--output=./output/frozen_inference_graph.pbtxt  
	--config=./ssd_mobilenet_v1_coco_2018_01_28/pipeline.config

七、测试图片

打开object_detection_test.py文件,修改:
模型位置:

MODEL_NAME = './output'

模型名

PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

模型标签文件

PATH_TO_LABELS = './dataset_label_map.pbtxt'

类别数

NUM_CLASSES = 7

python ./object_detection_test.py --image ./VOCdevkit/VOC2007/JPEGImages/v03_062250.jpg

或者:

python ./object_detection_test.py  --data ./VOCdevkit/VOC2007/ImageSets/Main/val.txt

上一篇:(完全解决)Dataset not found. You can use download=True to download it


下一篇:【wikioi】1012 最大公约数和最小公倍数问题