smallcorgi/Faster-RCNN_TF训练自己的数据

  熟悉了github项目提供的训练测试后,可以来训练自己的数据了。本文只介绍改动最少的方法,只训练2个类,

即自己添加的类(如person)和 background,使用的数据格式为pascal_voc。

1.训练数据的准备
  先来看看data下的目录:

  smallcorgi/Faster-RCNN_TF训练自己的数据

  (1)Annotations 存放所有训练数据的xml文件,是图片的标注数据,

可以使用labelImg工具生成。github地址:https://github.com/tzutalin/labelImg.git
  (2)ImageSets 底下有个main文件夹,里面放的是4个txt文件,

分别为  test.txt,train.txt,trainval.txt,val.txt。

每个文件存放的都是相应的图片数据名称,不含后缀。
trainval是train和val的合集,两者的比例可以为1:1。

生成txt文件的方法可以参考本人的另一篇blog:http://www.cnblogs.com/danpe/p/7859635.html
  (3)JPEGImages 是存放所有训练图片的目录。

注:修改为训练数据后,需要删除data/cache 下的pkl文件,不然不会去获取修改的数据,而是使用该缓存。
2.修改项目部分代码文件
  由于我们只训练了2个类,所以需要对代码中有关类的数目的地方进行修改。
  (1)lib/datasets/pascal_voc.py

   class pascal_voc(imdb):
def __init__(self, image_set, year, devkit_path=None):
imdb.__init__(self, 'voc_' + year + '_' + image_set)
self._year = year
self._image_set = image_set
self._devkit_path = self._get_default_path() if devkit_path is None \
else devkit_path
self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
# modified
# self._classes = ('__background__', # always index 0
# 'aeroplane', 'bicycle', 'bird', 'boat',
# 'bottle', 'bus', 'car', 'cat', 'chair',
# 'cow', 'diningtable', 'dog', 'horse',
# 'motorbike', 'person', 'pottedplant',
# 'sheep', 'sofa', 'train', 'tvmonitor')
self._classes = ('__background__', # always index 0
'person')

  (2)lib/datasets/pascal_voc2.py,与pascal_voc.py文件类似。
  (3)lib/networks/VGGnet_train.py

    import tensorflow as tf
from networks.network import Network #define

# modified
#n_classes = 21
n_classes = 2
_feat_stride = [16,]
anchor_scales = [8, 16, 32]

  (4)lib/networks/VGGnet_test.py,与VGGnet_train.py文件类似。
  (5)tools/demo.py

 import os, sys, cv2
import argparse
from networks.factory import get_network

# modified
#CLASSES = ('__background__',
# 'aeroplane', 'bicycle', 'bird', 'boat',
# 'bottle', 'bus', 'car', 'cat', 'chair',
# 'cow', 'diningtable', 'dog', 'horse',
# 'motorbike', 'person', 'pottedplant',
# 'sheep', 'sofa', 'train', 'tvmonitor') CLASSES = ('__background__',
'person')

  注:如果修改的.py文件有对应的.pyc文件,需要对pyc文件重新编译,方法为

  import py_compile

  py_compile.compile(dir/filename)

    
3.执行训练的脚本
    ./experiments/scripts/faster_rcnn_end2end.sh $DEVICE $DEVICE_ID VGG16 pascal_voc

上一篇:mysql 插入百万条数据


下一篇:Java多线程系列--“JUC线程池”05之 线程池原理(四)