一、定义待预测数据
# 待预测图片 test_img_path = ["./img.png"] import matplotlib.pyplot as plt import matplotlib.image as mpimg img = mpimg.imread(test_img_path[0]) # 展示待预测图片 plt.figure(figsize=(10,10)) plt.imshow(img) plt.axis('off') plt.show()
返回:
若是待预测图片存放在一个文件中,如左侧文件夹所示的test.txt。每一行是待预测图片的存放路径。
代码:
with open('mask.txt', 'r') as f: try: test_img_path=[] for line in f: test_img_path.append(line.strip()) except: print('图片加载失败') print(test_img_path)
返回:
二、加载预训练模型
PaddleHub口罩检测提供了两种预训练模型,pyramidbox_lite_mobile_mask和pyramidbox_lite_server_mask。二者均是基于2018年百度发表于计算机视觉*会议ECCV 2018的论文PyramidBox而研发的轻量级模型,模型基于主干网络FaceBoxes,对于光照、口罩遮挡、表情变化、尺度变化等常见问题具有很强的鲁棒性。不同点在于,pyramidbox_lite_mobile_mask是针对于移动端优化过的模型,适合部署于移动端或者边缘检测等算力受限的设备上。
代码:
import paddlehub as hub module = hub.Module(name="pyramidbox_lite_mobile_mask") # module = hub.Module(name="pyramidbox_lite_server_mask")
三、预测
PaddleHub对于支持一键预测的module,可以调用module的相应预测API,完成预测功能。
# 口罩检测预测 visualization=True #将预测结果保存图片可视化 output_dir='detection_result' #预测结果图片保存在当前运行路径下detection_result文件夹下 results = module.face_detection(images=imgs, use_multi_scale=True, shrink=0.6, visualization=True, output_dir='detection_result/test.jpg') for result in results: print(result) # 预测结果展示 import matplotlib.image as im import matplotlib.pyplot as plt import os # 需要读取的路径 path_name = r'./detection_result' for item in os.listdir(path=path_name): img = im.imread(os.path.join(path_name, item)) plt.imshow(img) plt.show()
返回如下:
其中,label有’MASK’和’NO MASK’两种选择:'MASK’表示戴了口罩,'NO MASK表示没有佩戴口罩。‘left’/‘rigth’/‘top’/'bottom’表示口罩在图片当中的位置。'confidence’表示预测为佩戴口罩’MASK’或者不佩戴口罩’NO MASK’的概率大小。同时,作为一项完善的开源工作,除了本地推断以外,PaddleHub还支持将该预训练模型部署到服务器或移动设备中。
四.完整源码
需要文件也可以左侧联系我,当然我也是百度随便找的。
# coding=gbk """ 作者:川川 @时间 : 2021/8/30 0:14 群:970353786 """ # 待预测图片 # test_img_path = ["./img.png"] import matplotlib.pyplot as plt import matplotlib.image as mpimg # img = mpimg.imread(test_img_path[0]) # 展示待预测图片 # plt.figure(figsize=(10,10)) # plt.imshow(img) # plt.axis('off') # plt.show() with open('mask.txt', 'r') as f: try: test_img_path=[] for line in f: test_img_path.append(line.strip()) except: print('图片加载失败') print(test_img_path) # import os import cv2 # imgs =[cv2.imread(image_path) for image_path in test_img_path] imgs=[cv2.imread(test_img_path[0])] # for i in imgs: #加载模块 import paddlehub as hub module = hub.Module(name="pyramidbox_lite_mobile_mask") # module = hub.Module(name="pyramidbox_lite_server_mask") # 口罩检测预测 visualization=True #将预测结果保存图片可视化 output_dir='detection_result' #预测结果图片保存在当前运行路径下detection_result文件夹下 results = module.face_detection(images=imgs, use_multi_scale=True, shrink=0.6, visualization=True, output_dir='detection_result') for result in results: print(result) # 预测结果展示 import matplotlib.image as im import matplotlib.pyplot as plt import os # 需要读取的路径 path_name = r'./detection_result' for item in os.listdir(path=path_name): img = im.imread(os.path.join(path_name, item)) plt.imshow(img) plt.show()
如果你想放在服务器上:
执行如下命令启动模型:
hub serving start -m pyramidbox_lite_server_mask -p 8866
代码为: