YOLOv4 Project Notes 4: The Testing Process

Contents

I. Overview

II. Testing Process

1. Preparing the parameters

2. Defining the model

3. Loading the required data

4. Running the model

5. Filtering the anchor boxes

① Filtering by objectness confidence

② Extracting the classes

③ Sorting by objectness confidence

④ Non-maximum suppression

6. Drawing the bounding boxes

III. Full Code


I. Overview

Before training, let's load the pre-trained weights, feed a single image through the network, and see what the testing pipeline actually looks like. As in earlier posts, I'll show the code block by block first and then put it all together at the end.

II. Testing Process

1. Preparing the parameters

Inside the class we need the model itself, a function for loading the pre-trained weights, the decode layers, and a few other helper methods. We prepare the keyword arguments up front and simply pass them in.

The first three entries are the paths to the pre-trained weights, the anchor sizes, and the class list. model_image_size is the input image size, confidence is the threshold for filtering detections by objectness (only boxes scoring above it are kept), and cuda is the device flag.

    params = {
        "model_path": 'pth/yolo4_weights_my.pth',
        "anchors_path": 'work_dir/yolo_anchors_coco.txt',
        "classes_path": 'work_dir/coco_classes.txt',
        "model_image_size": (608, 608, 3),
        "confidence": 0.4,
        "cuda": True
    }

    model = Inference(**params)

2. Defining the model

With the parameters ready, we define the Inference class as our model. In it we load the pre-trained weights and initialize both the yolo network and the yolo decode layers we'll need.

YoloBody is the backbone + neck + head of our model; we only need to give it the number of input channels and the number of output classes. It produces three outputs, of shapes (1, 255, 19, 19), (1, 255, 38, 38), and (1, 255, 76, 76).

The 255 in the middle is 3 * (4 + 1 + 80), since we have 80 classes; the details were covered in an earlier post, so I won't repeat them here. These three outputs are what we decode.
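As a quick sanity check, those shapes follow directly from the input size and the head layout. A minimal sketch of the arithmetic (just the numbers, not project code):

num_classes = 80
channels = 3 * (4 + 1 + num_classes)  # 3 anchors per grid cell; 4 box offsets + 1 objectness + 80 class scores
print(channels)  # 255
for stride in (32, 16, 8):  # the three detection strides
    print((1, channels, 608 // stride, 608 // stride))
# (1, 255, 19, 19), (1, 255, 38, 38), (1, 255, 76, 76)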

YoloLayer is our decode layer, also described in an earlier post; passing the required arguments to __init__ completes its initialization. So by the end of __init__ we have self.net and self.yolo_decodes, the two pieces we'll use to run the test and decode the outputs; all that's left is to define methods below that call them.

The arguments to fill in: the image size, the anchor mask (used to pick the anchors for each scale), the number of classes, the prior anchor sizes, the number of prior anchors, and the scale factor.

class Inference(object):
    # ---------------------------------------------------#
    #   Initialize the model and parameters, and load the pre-trained weights
    # ---------------------------------------------------#
    def __init__(self, **kwargs):
        self.model_path = kwargs['model_path']
        self.anchors_path = kwargs['anchors_path']
        self.classes_path = kwargs['classes_path']
        self.model_image_size = kwargs['model_image_size']
        self.confidence = kwargs['confidence']
        self.cuda = kwargs['cuda']

        self.class_names = self.get_class()
        self.anchors = self.get_anchors()
        print(self.anchors)
        # ================= initialize the model
        self.net = YoloBody(3, len(self.class_names)).eval()
        self.load_model_pth(self.net, self.model_path)

        if self.cuda:
            self.net = self.net.cuda()
            self.net.eval()

        print('Finished!')

        self.yolo_decodes = []
        anchor_masks = [[0,1,2],[3,4,5],[6,7,8]]
        # ================= initialize the decoders; the network has three outputs, so we need three decode heads
        for i in range(3):
            head = YoloLayer(self.model_image_size, anchor_masks, len(self.class_names),
                                               self.anchors, len(self.anchors)//2).eval()
            self.yolo_decodes.append(head)


        print('{} model, anchors, and classes loaded.'.format(self.model_path))

3. Loading the required data

Since we're testing, we obviously need to feed in image data, preprocessed appropriately: for example, resized to the 608×608 the model expects, or to another multiple of 32 (the model has no fully connected layers; everything is fully convolutional, so beyond the stride there is no hard requirement on the input size).
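A minimal sketch of that size constraint (snap_to_stride is a hypothetical helper, not part of the project code): since the deepest feature map is downsampled 32×, an arbitrary target size can be rounded to the nearest multiple of 32 before resizing.

def snap_to_stride(size, stride=32):
    # round a desired input size to the nearest multiple of the network stride
    return max(stride, int(round(size / stride)) * stride)

print(snap_to_stride(600))  # 608
print(snap_to_stride(416))  # 416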

To predict classes, we also need the class names prepared in advance.

The two functions below load the class names and the image data respectively. The image data is fed into the model for prediction; the class data is used for filtering and labeling.

def load_class_names(namesfile):
    class_names = []
    with open(namesfile, 'r') as fp:
        lines = fp.readlines()
    for line in lines:
        line = line.rstrip()
        class_names.append(line)
    return class_names

def detect_image(self, image_src):
    h, w, _ = image_src.shape
    image = cv2.resize(image_src, (608, 608))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img = np.array(image, dtype=np.float32)
    img = np.transpose(img / 255.0, (2, 0, 1))
    images = np.asarray([img])

4. Running the model

with torch.no_grad():
    images = torch.from_numpy(images)
    if self.cuda:
        images = images.cuda()
    outputs = self.net(images)

With the network outputs in hand, we feed them into the decode modules to recover all the anchor box information, then concatenate everything into a single tensor of shape (1, 22743, 85). That's more than twenty thousand boxes with their attributes, and they now need to be filtered.

output_list = []
for i in range(3):
    output_list.append(self.yolo_decodes[i](outputs[i]))
output = torch.cat(output_list, 1)
print(output.shape)
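The 22743 comes straight from the three grid sizes, with 3 anchor boxes per grid cell:

print(3 * (19**2 + 38**2 + 76**2))  # 22743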

5. Filtering the anchor boxes

① Filtering by objectness confidence

The objectness confidence is the value at index 4 of the last dimension. We compare it against the configured threshold, 0.5 here, and keep only the boxes whose objectness exceeds it. After this filtering only 17 boxes remain, so image_pred has shape (17, 85).

def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
    # convert (cx, cy, w, h) to top-left and bottom-right corners
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for image_i, image_pred in enumerate(prediction):
        # first-round filtering by objectness confidence
        conf_mask = (image_pred[:, 4] >= conf_thres).squeeze()
        # ================ keep only the boxes that pass the threshold
        image_pred = image_pred[conf_mask]

        if not image_pred.size(0):
            continue

        # get the best class and its confidence: the score value and its index
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

② Extracting the classes

A single image may contain several detected objects, i.e. several classes, each matched by several boxes, so we need to collect every class that appears in the predictions and then filter the boxes again per class.

The last line of the code above takes the maximum over all the class scores: torch.max returns both the maximum value and its index, so for each of the 17 boxes we get the index of its highest-scoring class, class_pred, and the corresponding confidence, class_conf.

We concatenate this index and score onto the earlier box information, giving a 7-dimensional vector per box. Since the code above already replaced the first four values with the top-left (x1, y1) and bottom-right (x2, y2) corners, each row is now (x1, y1, x2, y2, obj_conf, class_conf, class_pred).

The last dimension is class_pred, the class index; calling unique() on it yields every predicted class. Here we get three classes in total: (1, 7, 16).

# each row is (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

# collect the predicted classes
unique_labels = detections[:, -1].cpu().unique()

if prediction.is_cuda:
    unique_labels = unique_labels.cuda()
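A toy illustration of the two tensor ops used here, torch.max with keepdim and unique(), on made-up scores (not real model output):

import torch

scores = torch.tensor([[0.1, 0.9, 0.0, 0.0],   # hypothetical class scores: 3 boxes over 4 classes
                       [0.2, 0.7, 0.1, 0.0],
                       [0.0, 0.1, 0.1, 0.8]])
class_conf, class_pred = torch.max(scores, 1, keepdim=True)
print(class_pred.squeeze())         # tensor([1, 1, 3]) -- best class per box
print(class_pred.float().unique())  # tensor([1., 3.]) -- the distinct predicted classes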

③ Sorting by objectness confidence

First we pull out all the predictions for one class, then sort them by objectness confidence. torch.sort returns the sorted confidences together with their indices, and we use those indices to reorder all the predictions for that class.

Here the first class has 4 predictions in total, so the shape is (4, 7).

for c in unique_labels:
    # all predictions for this class after the first-round filtering
    detections_class = detections[detections[:, -1] == c]
    # sort by objectness confidence
    _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
    detections_class = detections_class[conf_sort_index]
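The index tensor returned by torch.sort is what reorders the whole rows. A small illustration with dummy detections (only the confidence column, index 4, matters here):

import torch

det = torch.tensor([[0., 0., 1., 1., 0.3, 0.9, 0.],
                    [0., 0., 1., 1., 0.9, 0.8, 0.],
                    [0., 0., 1., 1., 0.6, 0.7, 0.]])
_, idx = torch.sort(det[:, 4], descending=True)
print(det[idx][:, 4])  # tensor([0.9000, 0.6000, 0.3000]) -- rows now ordered by objectness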

④ Non-maximum suppression

These four boxes all predict the same class, so some of them must be redundant. We take the one with the highest confidence and compute its IOU with each of the remaining three. With nms_thres set to 0.4 here, any box whose IOU exceeds the threshold is considered redundant and removed, keeping only the top box; a box below the threshold barely overlaps it, so it stays.

That gives us the final boxes and their attributes, which we append to output. The resulting output is (3, 7): three classes, each with its corresponding box information.

max_detections = []
while detections_class.size(0):
    # take the highest-confidence box of this class, then walk down the rest: remove any box whose overlap exceeds nms_thres
    max_detections.append(detections_class[0].unsqueeze(0))
    if len(detections_class) == 1:
        break
    ious = bbox_iou(max_detections[-1], detections_class[1:])
    detections_class = detections_class[1:][ious < nms_thres]

# stack
max_detections = torch.cat(max_detections).data
# Add max detections to outputs
output[image_i] = max_detections if output[image_i] is None else torch.cat(
    (output[image_i], max_detections))

Computing the IOU:

We already have the top-left corner (x1, y1) and the bottom-right corner (x2, y2) of each box.

Since we want the intersection area, and the coordinate origin is at the top-left, we take the maximum of the two boxes' top-left corners and the minimum of their bottom-right corners; that gives the top-left and bottom-right of the intersection.

Subtracting those corners directly could yield negative values, so we use torch.clamp to floor the difference at 0: if one box's right edge minus the other box's left edge is negative, the boxes don't intersect and the width is 0. The same goes for y.

Multiplying the clamped width and height gives the intersection area. The union area is the sum of the two box areas minus the intersection.

Intersection divided by union then gives the IOU of the first box against each of the other three.

def bbox_iou(box1, box2, x1y1x2y2=True):
    """
        Compute the IOU between two sets of boxes
    """
    if not x1y1x2y2:
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)

    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1e-3, min=0) * \
                 torch.clamp(inter_rect_y2 - inter_rect_y1 + 1e-3, min=0)

    b1_area = (b1_x2 - b1_x1 + 1e-3) * (b1_y2 - b1_y1 + 1e-3)
    b2_area = (b2_x2 - b2_x1 + 1e-3) * (b2_y2 - b2_y1 + 1e-3)

    iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)

    return iou
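A quick usage example of the bbox_iou above on hand-made corner-format boxes (one kept box against two candidates), to show how the nms_thres comparison plays out:

import torch

kept = torch.tensor([[0., 0., 10., 10.]])        # the highest-confidence box
rest = torch.tensor([[1., 1., 11., 11.],         # heavy overlap -> large IOU, suppressed
                     [20., 20., 30., 30.]])      # no overlap    -> IOU 0, kept
print(bbox_iou(kept, rest))  # roughly tensor([0.68, 0.00]); with nms_thres=0.4 the first is removed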

6. Drawing the bounding boxes

Pass the image, the boxes, the class names, and a filename to save into the function below, and it draws the bounding boxes through the cv2 drawing APIs and saves the result.

def plot_boxes_cv2(img, boxes, savename=None, class_names=None, color=None):
    img = np.copy(img)
    colors = np.array([[1, 0, 1], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]], dtype=np.float32)

    def get_color(c, x, max_val):
        # linearly interpolate between neighboring palette colors so each class id maps to a distinct color
        ratio = float(x) / max_val * 5
        i = int(math.floor(ratio))
        j = int(math.ceil(ratio))
        ratio = ratio - i
        r = (1 - ratio) * colors[i][c] + ratio * colors[j][c]
        return int(r * 255)

    width = img.shape[1]
    height = img.shape[0]
    for i in range(len(boxes)):
        box = boxes[i]
        x1 = int(box[0] * width)
        y1 = int(box[1] * height)
        x2 = int(box[2] * width)
        y2 = int(box[3] * height)

        if color:
            rgb = color
        else:
            rgb = (255, 0, 0)
        if len(box) >= 7 and class_names:
            cls_conf = box[5]
            cls_id = box[6]
            # print('%s: %f' % (class_names[cls_id], cls_conf))
            classes = len(class_names)
            offset = cls_id * 123457 % classes
            red = get_color(2, offset, classes)
            green = get_color(1, offset, classes)
            blue = get_color(0, offset, classes)
            if color is None:
                rgb = (red, green, blue)
            img = cv2.putText(img, class_names[int(cls_id)], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1.2, rgb, 2)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), rgb, 3)
    if savename:
        print("save plot results to %s" % savename)
        cv2.imwrite(savename, img)
    return img

III. Full Code

That's the whole testing pipeline. The code is split into two parts, the model file and the utilities file, listed below in that order.

# ======== File 1: the model ========
# Imports for this file; the module paths for YoloBody / YoloLayer and the
# utility functions are assumptions -- adjust them to your project layout.
import os

import cv2
import numpy as np
import torch

from model import YoloBody, YoloLayer
from utils import load_class_names, non_max_suppression, plot_boxes_cv2


class Inference(object):
    # ---------------------------------------------------#
    #   Initialize the model and parameters, and load the pre-trained weights
    # ---------------------------------------------------#
    def __init__(self, **kwargs):
        self.model_path = kwargs['model_path']
        self.anchors_path = kwargs['anchors_path']
        self.classes_path = kwargs['classes_path']
        self.model_image_size = kwargs['model_image_size']
        self.confidence = kwargs['confidence']
        self.cuda = kwargs['cuda']

        self.class_names = self.get_class()
        self.anchors = self.get_anchors()
        print(self.anchors)
        self.net = YoloBody(3, len(self.class_names)).eval()
        self.load_model_pth(self.net, self.model_path)

        if self.cuda:
            self.net = self.net.cuda()
            self.net.eval()

        print('Finished!')

        self.yolo_decodes = []
        anchor_masks = [[0,1,2],[3,4,5],[6,7,8]]
        for i in range(3):
            head = YoloLayer(self.model_image_size, anchor_masks, len(self.class_names),
                                               self.anchors, len(self.anchors)//2).eval()
            self.yolo_decodes.append(head)


        print('{} model, anchors, and classes loaded.'.format(self.model_path))

    def load_model_pth(self, model, pth):
        print('Loading weights into state dict, name: %s' % (pth))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_dict = model.state_dict()
        pretrained_dict = torch.load(pth, map_location=device)
        matched_dict = {}

        with open('pretrained_.txt', 'w') as f:
            for k, v in pretrained_dict.items():
                f.write(k+'\n')
        with open('myparams_.txt', 'w') as f:
            for k, v in model_dict.items():
                f.write(k+'\n')


        for k, v in pretrained_dict.items():
            if k in model_dict and np.shape(model_dict[k]) == np.shape(v):
                matched_dict[k] = v
            else:
                print('unmatched layer: %s' % k)
        print(len(model_dict.keys()), len(pretrained_dict.keys()))
        print('%d layers matched,  %d layers missed' % (
            len(matched_dict.keys()), len(model_dict) - len(matched_dict.keys())))
        model_dict.update(matched_dict)
        # load the merged dict, not the raw pretrained dict, so shape-mismatched layers keep their init
        model.load_state_dict(model_dict)
        print('Finished!')
        return model

    # ---------------------------------------------------#
    #   Load all the class names
    # ---------------------------------------------------#
    def get_class(self):
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    # ---------------------------------------------------#
    #   Load all the prior anchors
    # ---------------------------------------------------#
    def get_anchors(self):
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return anchors
        #return np.array(anchors).reshape([-1, 3, 2])[::-1, :, :]


    # ---------------------------------------------------#
    #   Run detection on an image
    # ---------------------------------------------------#
    def detect_image(self, image_src):
        h, w, _ = image_src.shape
        image = cv2.resize(image_src, (608, 608))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img = np.array(image, dtype=np.float32)
        img = np.transpose(img / 255.0, (2, 0, 1))
        images = np.asarray([img])

        with torch.no_grad():
            images = torch.from_numpy(images)
            if self.cuda:
                images = images.cuda()
            outputs = self.net(images)

        output_list = []
        for i in range(3):
            output_list.append(self.yolo_decodes[i](outputs[i]))
        output = torch.cat(output_list, 1)
        print(output.shape)
        batch_detections = non_max_suppression(output, len(self.class_names),
                                               conf_thres=self.confidence,
                                               nms_thres=0.1)
        boxes = [box.cpu().numpy() for box in batch_detections]
        print(boxes[0])
        return boxes[0]


if __name__ == '__main__':
    params = {
        "model_path": 'pth/yolo4_weights_my.pth',
        "anchors_path": 'work_dir/yolo_anchors_coco.txt',
        "classes_path": 'work_dir/coco_classes.txt',
        "model_image_size": (608, 608, 3),
        "confidence": 0.4,
        "cuda": True
    }

    model = Inference(**params)
    class_names = load_class_names(params['classes_path'])
    image_src = cv2.imread('dog.jpg')
    boxes = model.detect_image(image_src)
    plot_boxes_cv2(image_src, boxes, savename='output3.jpg', class_names=class_names)
# ======== File 2: the utilities ========
import torch
import numpy as np
import math
import cv2


def plot_boxes_cv2(img, boxes, savename=None, class_names=None, color=None):
    img = np.copy(img)
    colors = np.array([[1, 0, 1], [0, 0, 1], [0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]], dtype=np.float32)

    def get_color(c, x, max_val):
        # linearly interpolate between neighboring palette colors so each class id maps to a distinct color
        ratio = float(x) / max_val * 5
        i = int(math.floor(ratio))
        j = int(math.ceil(ratio))
        ratio = ratio - i
        r = (1 - ratio) * colors[i][c] + ratio * colors[j][c]
        return int(r * 255)

    width = img.shape[1]
    height = img.shape[0]
    for i in range(len(boxes)):
        box = boxes[i]
        x1 = int(box[0] * width)
        y1 = int(box[1] * height)
        x2 = int(box[2] * width)
        y2 = int(box[3] * height)

        if color:
            rgb = color
        else:
            rgb = (255, 0, 0)
        if len(box) >= 7 and class_names:
            cls_conf = box[5]
            cls_id = box[6]
            # print('%s: %f' % (class_names[cls_id], cls_conf))
            classes = len(class_names)
            offset = cls_id * 123457 % classes
            red = get_color(2, offset, classes)
            green = get_color(1, offset, classes)
            blue = get_color(0, offset, classes)
            if color is None:
                rgb = (red, green, blue)
            img = cv2.putText(img, class_names[int(cls_id)], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1.2, rgb, 2)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), rgb, 3)
    if savename:
        print("save plot results to %s" % savename)
        cv2.imwrite(savename, img)
    return img


def load_class_names(namesfile):
    class_names = []
    with open(namesfile, 'r') as fp:
        lines = fp.readlines()
    for line in lines:
        line = line.rstrip()
        class_names.append(line)
    return class_names


def bbox_iou1(box1, box2, x1y1x2y2=True):
    # print('iou box1:', box1)
    # print('iou box2:', box2)

    if x1y1x2y2:
        mx = min(box1[0], box2[0])
        Mx = max(box1[2], box2[2])
        my = min(box1[1], box2[1])
        My = max(box1[3], box2[3])
        w1 = box1[2] - box1[0]
        h1 = box1[3] - box1[1]
        w2 = box2[2] - box2[0]
        h2 = box2[3] - box2[1]
    else:
        w1 = box1[2]
        h1 = box1[3]
        w2 = box2[2]
        h2 = box2[3]

        mx = min(box1[0], box2[0])
        Mx = max(box1[0] + w1, box2[0] + w2)
        my = min(box1[1], box2[1])
        My = max(box1[1] + h1, box2[1] + h2)
    uw = Mx - mx
    uh = My - my
    cw = w1 + w2 - uw
    ch = h1 + h2 - uh
    carea = 0
    if cw <= 0 or ch <= 0:
        return 0.0

    area1 = w1 * h1
    area2 = w2 * h2
    carea = cw * ch
    uarea = area1 + area2 - carea
    return carea / uarea


def bbox_iou(box1, box2, x1y1x2y2=True):
    """
        Compute the IOU between two sets of boxes
    """
    if not x1y1x2y2:
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)

    inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1e-3, min=0) * \
                 torch.clamp(inter_rect_y2 - inter_rect_y1 + 1e-3, min=0)

    b1_area = (b1_x2 - b1_x1 + 1e-3) * (b1_y2 - b1_y1 + 1e-3)
    b2_area = (b2_x2 - b2_x1 + 1e-3) * (b2_y2 - b2_y1 + 1e-3)

    iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)

    return iou


def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
    # convert (cx, cy, w, h) to top-left and bottom-right corners
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for image_i, image_pred in enumerate(prediction):
        # first-round filtering by objectness confidence
        conf_mask = (image_pred[:, 4] >= conf_thres).squeeze()
        image_pred = image_pred[conf_mask]

        if not image_pred.size(0):
            continue

        # get the best class and its confidence: the score value and its index
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # each row is (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

        # collect the predicted classes
        unique_labels = detections[:, -1].cpu().unique()

        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()

        for c in unique_labels:
            # all predictions for this class after the first-round filtering
            detections_class = detections[detections[:, -1] == c]
            # sort by objectness confidence
            _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
            detections_class = detections_class[conf_sort_index]
            # non-maximum suppression
            max_detections = []
            while detections_class.size(0):
                # take the highest-confidence box of this class, then walk down the rest: remove any box whose overlap exceeds nms_thres
                max_detections.append(detections_class[0].unsqueeze(0))
                if len(detections_class) == 1:
                    break
                ious = bbox_iou(max_detections[-1], detections_class[1:])
                detections_class = detections_class[1:][ious < nms_thres]

            # stack
            max_detections = torch.cat(max_detections).data
            # Add max detections to outputs
            output[image_i] = max_detections if output[image_i] is None else torch.cat(
                (output[image_i], max_detections))
    return output
