mmdet 是一个开源项目,有很多人使用,但初学者往往不知道如何调用:虽然能够训练模型,却常常为编写测试代码犯难。趁着项目需要,我编写了一份测试代码,供大家学习。
准备:
1. config:模型配置参数文件(.py)
2. checkpoint:训练好的权重文件(.pth)
3. classes:类别名称列表;若不提供,则以数字类别编号代替
4. 测试图片路径:存放待测试图片的文件夹;图片也可以分散存放在多个子文件夹中,代码会自动查找。注意:只会处理不含子文件夹的最底层文件夹中的图片。
详细如下:
from mmdet.apis import inference_detector, init_detector
import cv2
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import time
# Use GPU 7 unless the caller already chose devices
# (e.g. `CUDA_VISIBLE_DEVICES=0 python test.py`); setdefault keeps that choice
# instead of overwriting it. This must run before CUDA is initialized.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")
class Model():
    """Wrapper around an MMDetection detector used to batch-test images.

    Args:
        root_config: path to the mmdet config (.py) file.
        root_checkpoint: path to the trained weights (.pth) file.
        **kwargs: optional settings:
            thr_ok (float): keep only boxes whose score is > thr_ok
                (default 0.05).
            classes (list | None): class names indexed by 0-based label;
                when None, the 1-based category id is reported instead.
    """

    def __init__(self, root_config, root_checkpoint, **kwargs):
        self.model = init_detector(root_config, root_checkpoint)  # build detector and load weights
        self.thr_ok = kwargs.get('thr_ok', 0.05)
        self.classes = kwargs.get('classes', None)
        self.color = self.get_color()
        # File extensions treated as images by main(). The misspelled
        # attribute name is kept because callers read it.
        self.img_foramt = ['.jpg', '.JPG', '.bmp', '.png']

    def get_color(self):
        """Return a palette mapping color names to BGR tuples (OpenCV order)."""
        return dict(red=(0, 0, 255),
                    green=(0, 255, 0),
                    blue=(255, 0, 0),
                    cyan=(255, 255, 0),
                    yellow=(0, 255, 255),
                    magenta=(255, 0, 255),
                    white=(255, 255, 255),
                    black=(0, 0, 0))

    def model_test(self, result, img_name, classes, thr_ok=0.05):
        """Convert a raw detector result into a DataFrame of detections.

        Args:
            result: per-category sequence of box arrays whose rows are
                [x1, y1, x2, y2, score] (mmdet 2.x style, assumed). A
                (bbox_result, segm_result) tuple, as returned by mask
                models, is also accepted; only the boxes are used.
            img_name: value stored in the 'img_name' column.
            classes: category names indexed by 0-based label, or None to
                report the 1-based category id.
            thr_ok: boxes are kept only if score > thr_ok.

        Returns:
            pandas.DataFrame with columns ['img_name', 'cats', 'bbox',
            'score']; 'bbox' holds [x1, y1, x2, y2] rounded to 2 decimals.
        """
        # Mask models return (bbox_result, segm_result); keep only boxes.
        if isinstance(result, tuple):
            result = result[0]
        rows = []
        for label, boxes in enumerate(result):  # one entry per category
            category_id = label + 1
            category = classes[label] if classes is not None else category_id
            for box in boxes:
                score = float(box[4])
                # Positive comparison so NaN scores are dropped, as before.
                if score > thr_ok:
                    coord = [round(float(v), 2) for v in box[:4]]
                    rows.append({'img_name': img_name, 'cats': category,
                                 'bbox': coord, 'score': score})
        return pd.DataFrame(rows, columns=['img_name', 'cats', 'bbox', 'score'])

    def single_test(self, root_img):
        """Run the detector on one image file and return its raw result.

        Raises:
            FileNotFoundError: if OpenCV cannot read root_img.
        """
        img = cv2.imread(root_img)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising, so fail here with a clear message.
            raise FileNotFoundError('cannot read image: {}'.format(root_img))
        return inference_detector(self.model, img)

    def run(self, img_root):
        """Detect objects in one image.

        Returns:
            (img_name_lst, cat_lst, box_lst, score_lst): parallel lists with
            one entry per kept detection.
        """
        model_result = self.single_test(img_root)
        img_name = self.get_strfile(img_root, pos=-1)
        result_df = self.model_test(model_result, img_name, self.classes, thr_ok=self.thr_ok)
        return self.pd2lst(result_df)

    def pd2lst(self, result_df):
        """Split a model_test() DataFrame into four parallel lists.

        Returns:
            (img_name_lst, cat_lst, box_lst, score_lst); all empty when the
            DataFrame has no rows.
        """
        if result_df.empty:
            return [], [], [], []
        # Column-wise tolist() avoids a per-row .loc lookup and does not rely
        # on the index being 0..n-1.
        return (result_df['img_name'].tolist(),
                result_df['cats'].tolist(),
                result_df['bbox'].tolist(),
                result_df['score'].tolist())

    def draw_bbox(self, img, cat_lst, box_lst, score_lst,
                  bbox_color='green',
                  text_color='green',
                  thickness=1,
                  font_scale=0.5
                  ):
        """Draw boxes and 'category:score' labels onto img in place.

        Args:
            img: image array to draw on (modified in place).
            cat_lst, box_lst, score_lst: parallel lists as returned by run().
            bbox_color, text_color: keys of the palette from get_color().
            thickness: rectangle line thickness.
            font_scale: label font scale.

        Returns:
            The same img array, annotated.
        """
        bbox_bgr = self.color[bbox_color]
        text_bgr = self.color[text_color]
        font = cv2.FONT_HERSHEY_COMPLEX
        for cat, box, score in zip(cat_lst, box_lst, score_lst):
            x1, y1, x2, y2 = np.array(box).astype(np.int32).tolist()
            cv2.rectangle(img, (x1, y1), (x2, y2), bbox_bgr, thickness=thickness)
            label_text = '{}:{}'.format(str(cat), str(round(score, 4)))
            (_, text_h), _ = cv2.getTextSize(label_text, font, font_scale, 1)
            # Put the label just above the box. If that would go past the top
            # edge, put it just inside the box so it is not clipped.
            text_y = y1 - 2 if y1 - 2 - text_h >= 0 else y1 + text_h + 2
            cv2.putText(img, label_text, (x1, text_y), font, font_scale, text_bgr)
        return img

    def get_strfile(self, file_str, pos=-1):
        """Return component `pos` of a path split on both '/' and backslash.

        pos=-1 gives the file name, pos=-2 its parent folder name. Both
        separators are handled together because os.path.join on Windows can
        produce mixed paths such as 'D:/data/train' + backslash + 'img.jpg'.
        """
        return file_str.replace('\\', '/').split('/')[pos]

    def build_dir(self, out_dir):
        """Create out_dir (including parents) if needed and return it."""
        # exist_ok=True avoids the race between an exists() check and makedirs().
        os.makedirs(out_dir, exist_ok=True)
        return out_dir

    def show_img(self, img):
        """Display an OpenCV (BGR) image with matplotlib."""
        import matplotlib.pyplot as plt
        # OpenCV color images are BGR while matplotlib expects RGB; reverse
        # the channel axis so colors are shown correctly.
        if img.ndim == 3 and img.shape[2] == 3:
            img = img[..., ::-1]
        plt.imshow(img)
        plt.show()
def get_files_root(root):
    '''
    Return the leaf folders under root, i.e. folders with no sub-folders.

    If root itself has no sub-folders, [root] is returned. Intermediate
    folders are excluded, so images stored directly in a folder that also
    contains sub-folders are not returned for processing.

    Raises:
        FileNotFoundError: if root is not an existing directory.
    '''
    if not os.path.isdir(root):
        # os.walk silently yields nothing for a missing root; fail loudly
        # instead so a mistyped path is not reported as "0 images".
        raise FileNotFoundError('test image folder not found: {}'.format(root))
    leaves = []
    # followlinks=True keeps descending into symlinked folders, as the
    # previous os.listdir-based search did.
    for dirpath, dirnames, _ in os.walk(root, followlinks=True):
        if not dirnames:
            leaves.append(dirpath)
    return leaves
def build_files(root):
    '''
    Return the entries of root for which os.path.isfile is False (normally
    its sub-folders), each joined onto root. Despite the name, regular
    files are filtered out.
    '''
    children = (os.path.join(root, entry) for entry in os.listdir(root))
    return [child for child in children if not os.path.isfile(child)]
def single_main(model, root_img, work_dir):
    """Detect objects in one image and save an annotated copy.

    The annotated image is written to
    work_dir/<parent folder name of root_img>/<image file name>.

    Raises:
        FileNotFoundError: if OpenCV cannot read root_img.
        OSError: if the annotated image cannot be written.
    """
    _, cat_lst, box_lst, score_lst = model.run(root_img)
    img = cv2.imread(root_img)
    if img is None:
        raise FileNotFoundError('cannot read image: {}'.format(root_img))
    img = model.draw_bbox(img, cat_lst, box_lst, score_lst)
    folder_name = model.get_strfile(root_img, pos=-2)
    out_dir = model.build_dir(os.path.join(work_dir, folder_name))
    img_name = model.get_strfile(root_img, pos=-1)
    out_path = os.path.join(out_dir, img_name)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(out_path, img):
        raise OSError('failed to write image: {}'.format(out_path))
def main(root, model, work_dir):
    """Run single_main on every image in the folders from get_files_root(root).

    A file counts as an image when its name ends with one of
    model.img_foramt, compared case-insensitively. Prints the number of
    processed images.
    """
    root_files = get_files_root(root)
    # Hoisted out of the loops: str.endswith accepts a tuple. Lower-casing
    # both sides also matches e.g. '.PNG' or '.Jpg', and the check is not
    # limited to 4-character extensions.
    img_exts = tuple(ext.lower() for ext in model.img_foramt)
    num = 0
    for file_path in tqdm(root_files):
        # leave=False: don't keep one finished progress bar per folder.
        for name in tqdm(os.listdir(file_path), leave=False):
            if name.lower().endswith(img_exts):
                single_main(model, os.path.join(file_path, name), work_dir)
                num += 1
    print('num of images:', num)
if __name__ == '__main__':
    import argparse

    # Defaults reproduce the original hard-coded paths, so running without
    # arguments behaves as before.
    parser = argparse.ArgumentParser(description='Batch-test an mmdet model on a folder of images.')
    parser.add_argument('--config',
                        default='/data/sdv3/tangjun/xmtm/xmtm_pointer/code/model_new_meter/parameters.py',
                        help='mmdet config (.py) file')
    parser.add_argument('--checkpoint',
                        default='/data/sdv3/tangjun/xmtm/xmtm_pointer/code/model_new_meter/model.pth',
                        help='trained weights (.pth) file')
    parser.add_argument('--img-dir',
                        default='/data/sdv3/tangjun/xmtm/xmtm_pointer/data/data_0512/data_step_two/train_step2/train',
                        help='folder with test images (sub-folders are searched)')
    parser.add_argument('--work-dir',
                        default='/data/sdv3/cj/First_Blood/tj/code/mmdet50/123/78/90',
                        help='output folder for annotated images')
    parser.add_argument('--thr', type=float, default=0.05,
                        help='score threshold; boxes with score > thr are kept')
    args = parser.parse_args()

    # Set 'classes' to a list of names to label detections by name instead of id.
    info = {'classes': None, 'thr_ok': args.thr}
    # perf_counter is monotonic, unlike time.time, so it is the right clock for elapsed time.
    time_start = time.perf_counter()
    model = Model(args.config, args.checkpoint, **info)  # loads the detector
    # Single-image prediction:
    #   img_name_lst, cat_lst, box_lst, score_lst = model.run(root_img)
    main(args.img_dir, model, args.work_dir)
    time_gap = time.perf_counter() - time_start
    print('time gap:', time_gap)