yolov5-V6 ->ONNX->TensorRT:
- ONNX最终结果正确
- TensorRT最终结果不正确
解决方案
导出模型时仅输出特征图, 不包含 Detect() 模块后续的解码(后处理)部分
1.yolo.py
class Detect
def forward(self, x):
    """Return the raw, reshaped head outputs instead of decoded detections.

    Excerpt of the patched ``Detect.forward``: while ``onnx_export`` is true
    the method returns before any grid/anchor decoding happens, so the
    exported graph only produces the per-level feature maps.

    Args:
        x: indexable collection of ``self.nl`` tensors. After ``self.m[i]``
            each one must be 4-D ``(bs, self.na * self.no, ny, nx)`` so that
            the ``view`` below is valid. The entries are overwritten in place.

    Returns:
        The same ``x`` object, each entry now a contiguous tensor shaped
        ``(bs, self.na, ny, nx, self.no)``.
    """
    z = []  # inference output (only used by the decoding path after this excerpt)
    # ===== added for ONNX export =====
    onnx_export = True
    if onnx_export:
        # Marker printed so it is visible that the export branch was taken.
        print("=======bobo====")
        for i in range(self.nl):
            # self.m[i]: per-level head layer (presumably the 1x1 conv - confirm in yolo.py)
            x[i] = self.m[i](x[i])
            # e.g. x(bs,48,20,20) -> x(bs,3,20,20,16) when na=3 and no=16
            bs, _, ny, nx = x[i].shape
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        return x
    # ===== end of added section =====
- 生成onnx文件时一定要简化
export.py
# Make --simplify default to True so the exported ONNX graph is always simplified.
# NOTE(review): no `action`/`type` is given, so any value passed on the command
# line is stored as a non-empty string and is therefore truthy (even "False").
parser.add_argument('--simplify', default=True, help='ONNX: simplify model')
- ONNX->TensorRT
- Torch后处理
导出的模型输出三个不同尺度的特征图, 再用 torch 完成生成网格(grid)和 anchor、解码预测框等后处理
import numpy as np
import torch
class Detect():
    """Decode raw YOLOv5 head feature maps into box predictions with torch.

    The exported model only emits the per-level feature maps; this class
    performs the remaining steps (sigmoid, grid offset, anchor scaling) that
    were left out of the exported graph.
    """

    # Default anchors, one row per detection level: [P3/8, P4/16, P5/32].
    DEFAULT_ANCHORS = (
        (10, 13, 16, 30, 33, 23),
        (30, 61, 62, 45, 59, 119),
        (116, 90, 156, 198, 373, 326),
    )

    def __init__(self, device="cuda:0", num_classes=2, anchors=None,
                 strides=(8., 16., 32.)):
        """Set up strides, anchors and the lazily built grid caches.

        Args:
            device: torch device on which anchors, strides and grids live.
            num_classes: number of classes; each prediction has
                ``num_classes + 5`` channels (4 box + 1 objectness + classes).
            anchors: per-level flat ``(w, h, w, h, ...)`` anchor sizes, one
                row per level; ``None`` selects ``DEFAULT_ANCHORS``.
            strides: one stride per detection level.

        Raises:
            ValueError: if ``anchors`` is not 2-D with an even number of
                values per row, or its row count differs from ``strides``.
        """
        self.device = device
        if anchors is None:
            anchors = self.DEFAULT_ANCHORS
        anchors_yaml = torch.as_tensor(anchors, dtype=torch.float32, device=device)
        self.stride = torch.as_tensor(strides, dtype=torch.float32, device=device)
        if (anchors_yaml.dim() != 2 or anchors_yaml.shape[1] % 2
                or anchors_yaml.shape[0] != self.stride.numel()):
            raise ValueError(
                "anchors must have one row of (w, h) pairs per stride")
        self.nl = anchors_yaml.shape[0]  # number of detection levels
        self.na = anchors_yaml.shape[1] // 2  # anchors predicted per grid cell
        self.no = num_classes + 5  # 4 box coords + 1 objectness + classes
        # Anchors expressed in grid-cell units, shape (nl, na, 2).
        self.anchors = (anchors_yaml / self.stride[..., None]).view(self.nl, -1, 2)
        # Placeholders; the real grids are built on first use per level.
        # Separate tensors per level so the list entries never alias.
        self.anchor_grid = [torch.zeros(1, device=device) for _ in range(self.nl)]
        self.grid = [torch.zeros(1, device=device) for _ in range(self.nl)]

    def after_process(self, x):
        """Decode the raw feature maps into one prediction tensor.

        Args:
            x: sequence of tensors, one per level, each already permuted to
                ``(bs, na, ny, nx, no)``.

        Returns:
            Tensor of shape ``(bs, sum(na * ny * nx), no)``. Channels 0:2 are
            the box centre and 2:4 the box size, both scaled by stride/anchor
            size; the remaining channels are sigmoid scores.
        """
        z = []
        for i, feat in enumerate(x):
            bs, _, ny, nx, _ = feat.shape  # (bs, na, ny, nx, no)
            # Rebuild the cached grid only when the spatial size changes.
            if self.grid[i].shape[2:4] != feat.shape[2:4]:
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
            y = feat.sigmoid()
            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
            # reshape (not view): also works when the input was not contiguous.
            z.append(y.reshape(bs, -1, self.no))
        return torch.cat(z, 1)

    def _make_grid(self, nx=20, ny=20, i=0):
        """Build the cell-offset grid and anchor-size grid for level ``i``.

        Returns:
            ``(grid, anchor_grid)``, both float tensors of shape
            ``(1, na, ny, nx, 2)``. ``grid[..., 0]`` is the column index and
            ``grid[..., 1]`` the row index; ``anchor_grid`` holds the level's
            anchors multiplied by its stride.
        """
        d = self.anchors[i].device
        # Same values as an 'ij' meshgrid, without the version-dependent call.
        yv = torch.arange(ny, device=d).view(ny, 1).expand(ny, nx)
        xv = torch.arange(nx, device=d).view(1, nx).expand(ny, nx)
        grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
        anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
            .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
        return grid, anchor_grid
# trt_result holds the raw outputs of the exported model (the original note
# calls it the ONNX output; the file name suggests it was saved from TensorRT).
device = "cuda:0"
# NOTE(security): allow_pickle=True can execute arbitrary code while loading;
# only load .npy files from a trusted source.
trt_result = np.load("/code/lipengbo/SexyDet/yolov5-v6/checkpoint/trt_result.npy", allow_pickle=True).tolist()
# Grid size of each output level; every level is reshaped to (1, 3, s, s, 7).
grid_sizes = (64, 32, 16)
x = [
    torch.from_numpy(trt_result[k].reshape([1, 3, s, s, 7])).to(device)
    for k, s in enumerate(grid_sizes)
]
# Build the decoder on the same device as the inputs so they cannot diverge.
detect = Detect(device=device)
final_result = detect.after_process(x)
print()