训练步骤
-
安装
labelme
conda create --name=labelme python=3.6 conda activate labelme pip install pyqt5 pip install labelme
-
安装
scikit-image
,scipy
pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple pip install scipy -i https://pypi.tuna.tsinghua.edu.cn/simple
-
运行
..\data\coco\train2014\
中labelme2coco.py
,生成instances_train2014.json
,并复制到..\data\coco\annotations\
中 -
运行
..\data\coco\val2014\
中labelme2coco.py
,生成instances_val2014.json
,并复制到..\data\coco\annotations\
中 -
修改
change.py
并运行:pretrained_weights = torch.load('F:\\BaiduNetdiskDownload\\mmdetection-1.1.0\\checkpoints\\faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth') num_class = 类别 + 1 torch.save(pretrained_weights, "F:\\BaiduNetdiskDownload\\mmdetection-1.1.0\\faster_rcnn_r50_fpn_1x_%d.pth"%num_class)
可以看到文件夹里多出一个文件:
faster_rcnn_r50_fpn_1x_2.pth
-
mmdetection-1.1.0\mmdet\utils\collect_env.py
注释42-44
行#gcc = subprocess.check_output('gcc --version | head -n1', shell=True) #gcc = gcc.decode('utf-8').strip() #env_info['GCC'] = gcc
-
修改
mmdetection-1.1.0\configs\My_faster_rcnn_r50_fpn_1x.py
:训练类别数: num_classes = 类别+1 训练次数: total_epochs = 10 学习率: lr = 0.02 第2步生成的权重文件: load_from = 'F:\\BaiduNetdiskDownload\\mmdetection\\mmdetection-1.1.0\\faster_rcnn_r50_fpn_1x_2.pth' 训练多少次保存一次权重: checkpoint_config = dict(interval=5) 填写你的val2014文件夹图片数目: log_config = dict( interval=10, hooks=[ dict(type='TextLoggerHook'), # dict(type='TensorboardLoggerHook') ])
-
修改
mmdetectio_fasterrcnn\mmdet\datasets\coco.py
中CLASSES
为自己要训练的类别名称:CLASSES = ('dog')
-
开始训练:
python tools/train.py configs/My_faster_rcnn_r50_fpn_1x.py
-
这一步可能会报错,不用理会:
'tail' 不是内部或外部命令,也不是可运行的程序或批处理文件。 'gcc' 不是内部或外部命令,也不是可运行的程序或批处理文件。
-
但如果这一步报错是:
OSError: symbolic link privilege not held
,需要以管理员模式打开进行运行
测试训练结果
-
训练模型在:
F:\BaiduNetdiskDownload\mmdetection\mmdetection-1.1.0\work_dirs\faster_rcnn_r50_fpn_1x\
-
修改
demo.py
:config_file = 'configs/My_faster_rcnn_r50_fpn_1x.py' checkpoint_file = 'work_dirs/faster_rcnn_r50_fpn_1x/epoch_10.pth' img = 'data/coco/val2014/000021.jpg'
-
运行
demo.py
:python demo.py
,这一步可能会报错:Traceback (most recent call last): File "demo.py", line 16, in <module> show_result(img, result, model.CLASSES) File "F:\BaiduNetdiskDownload\mmdetection\mmdetection-1.1.0\mmdet\apis\inference.py", line 143, in show_result assert isinstance(class_names, (tuple, list)) AssertionError