import torch
from torch._C import Size
from d2l import torch as d2l
# Compact tensor printing: show 2 digits after the decimal point.
torch.set_printoptions(precision=2)
def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes centred on every pixel of ``data``.

    For each pixel, ``len(sizes) + len(ratios) - 1`` anchors are produced:
    every size paired with ``ratios[0]``, followed by ``sizes[0]`` paired
    with each of ``ratios[1:]``.

    Args:
        data: tensor whose last two dimensions are the spatial height and
            width; only its spatial shape and its device are used.
        sizes: anchor scales. A size ``s`` gives an anchor whose height is
            ``s / sqrt(ratio)`` times the image height (so ``s=1, ratio=1``
            spans the full image height).
        ratios: anchor aspect ratios (width / height), measured in pixels.

    Returns:
        Tensor of shape ``(1, H * W * boxes_per_pixel, 4)`` holding
        ``(xmin, ymin, xmax, ymax)``. x is normalised by the image width and
        y by the image height; anchors near the border may extend outside
        ``[0, 1]``. Pixels are ordered row by row, and the
        ``boxes_per_pixel`` anchors of one pixel are contiguous.

    Raises:
        ValueError: if ``sizes`` or ``ratios`` is empty.
    """
    # sizes[0] and ratios[0] are used below; fail clearly instead of with an
    # IndexError deep inside the computation.
    if len(sizes) == 0 or len(ratios) == 0:
        raise ValueError("sizes and ratios must both be non-empty")

    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = num_sizes + num_ratios - 1
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)

    # A pixel has unit height and width; an offset of 0.5 moves the anchor
    # centre from the pixel's corner to its middle.
    offset_h, offset_w = 0.5, 0.5
    steps_h = 1.0 / in_height  # normalised height of one pixel
    steps_w = 1.0 / in_width   # normalised width of one pixel

    # Normalised centre coordinates of every pixel.
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
    # 'ij' is the historical default of torch.meshgrid; passing it explicitly
    # keeps the same result and avoids the deprecation warning.
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)

    # boxes_per_pixel widths and heights. Before the scaling factor, widths
    # are in units of the image height; multiplying by in_height / in_width
    # re-expresses them in units of the image width, so the aspect ratio
    # holds in pixels for non-square inputs.
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                   sizes[0] * torch.sqrt(ratio_tensor[1:]))) \
        * in_height / in_width
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                   sizes[0] / torch.sqrt(ratio_tensor[1:])))

    # Half-extent offsets (-w/2, -h/2, +w/2, +h/2) from each centre, tiled
    # once per pixel.
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
        in_height * in_width, 1) / 2

    # Each centre (x, y, x, y) is repeated boxes_per_pixel times so it lines
    # up row-for-row with the tiled offsets above.
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
                           dim=1).repeat_interleave(boxes_per_pixel, dim=0)
    output = out_grid + anchor_manipulations
    return output.unsqueeze(0)
# Demo: generate anchors for the cat-and-dog image and inspect one of them.
img = d2l.plt.imread('../img/catdog.jpg')
h, w = img.shape[:2]
print(h, w)
X = torch.rand(size=(1, 3, h, w))
Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])
# ``Tensor.shape`` is a torch.Size (a tuple), not a method: calling it raises
# TypeError. Values are printed explicitly because a bare expression shows
# nothing when this file runs as a script rather than in a notebook.
print(Y.shape)
# View as (row, column, anchor index, 4 coordinates). -1 lets reshape infer
# the len(sizes) + len(ratios) - 1 anchors per pixel instead of hard-coding 5.
boxes = Y.reshape(h, w, -1, 4)
print(boxes[250, 250, 0, :])
# Display all bounding boxes.
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """Draw bounding boxes, with optional text labels, onto matplotlib axes.

    Args:
        axes: object with ``add_patch`` and ``text`` methods (a matplotlib
            ``Axes``); each box is added to it as a patch.
        bboxes: iterable of boxes. Each element must support
            ``.detach().cpu().numpy()`` (e.g. a torch tensor), and its
            coordinates must be in whatever format ``d2l.bbox_to_rect``
            expects (presumably (x1, y1, x2, y2) -- confirm).
        labels: a single label or a list/tuple of labels. Label ``i`` is
            drawn at the rectangle's anchor point ``rect.xy`` of box ``i``;
            boxes past the end of the list get no label.
        colors: a single color or a list/tuple of colors, cycled over the
            boxes. Defaults to ``['b', 'g', 'r', 'm', 'c']``.
    """
    def _make_list(obj, default_values=None):
        # Normalise None / a single value / a list or tuple into something
        # indexable.
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    labels = _make_list(labels)
    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        # .cpu() allows GPU-resident tensors to be drawn; it is a no-op for
        # tensors already on the CPU.
        rect = d2l.bbox_to_rect(bbox.detach().cpu().numpy(), color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            # Black text on a white box, white text on any other color.
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=9, color=text_color,
                      bbox=dict(facecolor=color, lw=0))
#原链接:https://zh-v2.d2l.ai/chapter_preface/index.html