TensorFlow下使用YOLOv1训练自己的数据集+测试自己的模型
一. 前期准备
环境:(用cpu跑的)
win10 + python3.6.8 + tensorflow2.4.1+pycharm
ps:本来打算用tensflow-gpu 1.4.2运行的,但是该代码是2.xx版本的tensflow,需要安装tensflow-gpu 2.x.版本,以及cuda 11.0版本的 ,由于本人电脑比较差,最高cuda安装到8.0版本,又不想升级显卡驱动,所以用cpu跑的。
YOLOv1代码:
1.1准备自己要训练的图片
first 在data文件夹中创建一个myDataSet文件夹来保存爬取的图片,然后从myDataSet文件夹中挑选出符合自己数据集(比如我想训练只检测手机,所以挑选带手机的图片)的图片,删除不符合的图片
second 再接着 运行SelectSize.py文件,继续筛选文件
last 最后运行graph_rename.py文件,给上述筛选的文件命名,得到最终的原图片
我本人在网上爬取的图片:爬取代码以及其他代码如下:
GitHub - xucancan1617608769/yolov1
1.2 数据集准备
在Yolo_tensorflow\data\pascal_voc\VOCdevkit\VOC2007路径下,创建如下文件夹(ps:我的文件夹中已经创好了)
将1.1中data文件夹中的图片全部copy到此路径下的JPEGImages文件夹中,然后在ImageSets文夹下创建Main文件夹,接着用labelImg工具对JPEGImages文件夹的图片进行数据标记保存到Annotations文件夹下
工具:LabelImg 链接:链接:https://pan.baidu.com/s/1qwPnIx-T_-Kl1CawtE4M9g
提取码:8cb2
labelImg工具的使用参考链接:使用LabelImg标注图片 - AiFly - 博客园
像这样:
运行conert_to_txt.py文件生成ImageSets/Main/文件夹下的4个文件。(layout文件夹不用管)
像这样:
运行完后Main文件夹会生成yolov1所需的train.txt,val.txt,test.txt,trainval.txt
至此,自己要训练的数据集已准备完毕
二.修改配置
①修改 ..\yolo_tensorflow\yolo\config.py文件 如下
CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
'train', 'tvmonitor']
====>>CLASSES = ['phone', 'cat', 'dog', 'sheep'] # 修改为自己的类别
② ..\yolo_tensorflow\utils\pascal_voc.py文件
#第30行
labels = np.zeros(
(self.batch_size, self.cell_size, self.cell_size, 9)) # 25修改为5+类别数
#第129行
label = np.zeros((self.cell_size, self.cell_size, 9)) # 修改为5+类别数
开始训练,运行 train.py文件。
三 .保存权重文件 测试
① 将..\yolo_tensorflow\data\pascal_voc\output\2021_09_12_10_47\目录下(2021_09_12_10_47文件夹在模型开始训练时自动生成)保存的最后的模型文件(下图1)复制到..\yolo_tensorflow\data\weights\目录下(下图2),并在名称中间添加“.ckpt”重新命名。
上图1
上图2
② 修改..\yolo_tensorflow\test.py文件。
将待测试图片放入..\yolo_tensorflow-master\test\目录下。
# 在189行的默认权重文件改为自己训练好权重的模型
def main():
parser = argparse.ArgumentParser()
# parser.add_argument('--weights', default="YOLO_small.ckpt", type=str)
parser.add_argument('--weights', default="yolo-15000.ckpt", type=str)
# 将207行需要测试的图片名称修改为自己的图片名称
# detect from image file
imname = 'test/000004.jpg'
detector.image_detector(imname)
③ 运行test.py文件进行测试
结果:
四.结语
整个流程按照步骤走了一遍,中间需要安装一些第三方安装包,比如tensorflow2.x.x版本,以及cv2,tf_sim,最后结果不是很理想,定位存在很大误差,以后再更新改进。