如何生成锚框并在图片上可视化

import torch
from torch._C import Size
from d2l import torch as d2l

# Compact tensor printing: show only 2 digits after the decimal point.
torch.set_printoptions(precision=2)

def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes centred on every pixel of ``data``.

    Each pixel receives ``len(sizes) + len(ratios) - 1`` anchors: every size
    paired with ``ratios[0]``, plus ``sizes[0]`` paired with each of the
    remaining ratios.

    Args:
        data: Tensor whose last two dimensions are (height, width); only its
            shape and device are used.
        sizes: Anchor scales. With ratio 1 an anchor's normalized height
            equals its size.
        ratios: Aspect ratios; in pixel units every anchor satisfies
            width / height == ratio.

    Returns:
        Tensor of shape ``(1, height * width * boxes_per_pixel, 4)`` holding
        ``(xmin, ymin, xmax, ymax)`` corner coordinates normalized by the
        image width/height. Pixels are visited row by row, and the anchors
        of one pixel are contiguous.
    """
    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = num_sizes + num_ratios - 1
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)

    # Pixels are 1x1 cells: shift by 0.5 to land on the pixel centre, then
    # multiply by the per-pixel step to normalize into [0, 1].
    offset_h, offset_w = 0.5, 0.5
    steps_h = 1.0 / in_height
    steps_w = 1.0 / in_width

    # Centre coordinates of every pixel.
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
    # 'ij' is the historical default; passing it explicitly avoids the
    # deprecation warning PyTorch emits when `indexing` is omitted.
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)

    # boxes_per_pixel widths and heights, later turned into the corner
    # coordinates (xmin, ymin, xmax, ymax). Widths are scaled by
    # in_height / in_width so that, in pixels, width / height == ratio.
    w = (torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                    sizes[0] * torch.sqrt(ratio_tensor[1:])))
         * in_height / in_width)
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                   sizes[0] / torch.sqrt(ratio_tensor[1:])))

    # Divide by 2 to get half-widths/half-heights: (-w/2, -h/2, +w/2, +h/2)
    # added to a centre yields its corners. The pattern is tiled per pixel.
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
        in_height * in_width, 1) / 2

    # Each centre is repeated boxes_per_pixel times so it lines up with the
    # tiled offsets above.
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
                           dim=1).repeat_interleave(boxes_per_pixel, dim=0)
    output = out_grid + anchor_manipulations
    return output.unsqueeze(0)

img = d2l.plt.imread('../img/catdog.jpg')
h, w = img.shape[:2]
print(h, w)

# Dummy batch with the image's spatial size; multibox_prior only reads the
# shape and device of its input.
X = torch.rand(size=(1, 3, h, w))
sizes, ratios = [0.75, 0.5, 0.25], [1, 2, 0.5]
Y = multibox_prior(X, sizes=sizes, ratios=ratios)
# `shape` is a torch.Size attribute, not a method (calling it raised TypeError).
print(Y.shape)

# Anchors per pixel, derived from the arguments instead of hard-coding 5.
boxes_per_pixel = len(sizes) + len(ratios) - 1
boxes = Y.reshape(h, w, boxes_per_pixel, 4)
# First anchor centred on pixel (row 250, col 250), as (xmin, ymin, xmax, ymax).
print(boxes[250, 250, 0, :])

# Display all bounding boxes.
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """Draw bounding boxes, with optional labels, onto ``axes``.

    Args:
        axes: Matplotlib-style axes; ``add_patch`` and ``text`` are called on it.
        bboxes: Iterable of boxes; each is converted by ``d2l.bbox_to_rect``.
            Tensors are detached and moved to the CPU first, so boxes that
            live on a GPU work too; other values are passed through unchanged.
        labels: A single label or a list/tuple of labels; box ``i`` is
            annotated with ``labels[i]`` when that entry exists.
        colors: A single colour or a list/tuple of colours, cycled over the
            boxes. Defaults to ``['b', 'g', 'r', 'm', 'c']``.
    """
    def _make_list(obj, default_values=None):
        # Normalise None / scalar / sequence arguments into a list-like value.
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    labels = _make_list(labels)
    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        if isinstance(bbox, torch.Tensor):
            # .numpy() fails on CUDA tensors; .cpu() is a no-op on CPU ones.
            bbox = bbox.detach().cpu().numpy()
        rect = d2l.bbox_to_rect(bbox, color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            # Black text on a white box, white text on any other colour.
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=9, color=text_color,
                      bbox=dict(facecolor=color, lw=0))
#原链接:https://zh-v2.d2l.ai/chapter_preface/index.html
上一篇:LDO的原理及应用


下一篇:(LDO)MX78LXXPlasticEncapsulate 稳压器