Person_reID_baseline_pytorch 源码解析之 test.py

源码中有两个用于测试的脚本: test.py 和 evaluate_gpu.py 。其中, test.py 加载通过脚本 train.py 训练好的模型,实现对 query 和 gallery 图片的特征提取;本文对脚本 test.py 进行解析。

1. 加载模型和数据

首先需要载入训练好的模型,这里以基于 Resnet50 输出类别为 751 类的行人重识别模型 ft_net 为例。

model_structure = ft_net(751)
model = load_network(model_structure)

然后需要载入经过预处理的 gallery 和 query 数据集

data_transforms = transforms.Compose([
        transforms.Resize((256,128), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=0) for x in ['gallery','query']}

加载预处理过的数据集和训练好的模型,然后使用函数 extract_feature 进行特征提取

with torch.no_grad():
    gallery_feature = extract_feature(model,dataloaders['gallery'])
    query_feature = extract_feature(model,dataloaders['query'])

2. 完成特征提取

extract_feature 是 test.py 中非常重要的一个函数,用于提取图片的特征,下面对它逐行解析

def extract_feature(model,dataloaders):
    features = torch.FloatTensor()
    count = 0
    # 加载数据集
    for data in dataloaders:
        img, label = data
        n, c, h, w = img.size()
        count += n
        # 统计数据集图片数量
        print(count)
        ff = torch.FloatTensor(n,512).zero_().cuda()
        for i in range(2):
            if(i==1):
            	# 翻转图片
                img = fliplr(img)
            # 将图片变成 Variable,准备加载到网络中
            input_img = Variable(img.cuda())
            # 缩放尺寸 multiple_scale
            for scale in ms:
                if scale != 1:
                    # bicubic is only  available in pytorch>= 1.1
                    input_img = nn.functional.interpolate(input_img, scale_factor=scale, mode='bicubic', align_corners=False)
                # 模型推理
                outputs = model(input_img) 
                # 拼接多尺度预测结果
                ff += outputs
        # norm feature 特征归一化
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))
		# 返回提取到的特征
        features = torch.cat((features,ff.data.cpu()), 0)
    return features

3. 实现特征归一化

fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)

这里是在输入张量 ff 的第 1 维进行 L2-norm,即 2 范数归一化。特征向量中每个元素均除以向量的L2范数。

Person_reID_baseline_pytorch 源码解析之 test.py
pytorch 中使用 torch.norm 计算张量的范数。

fnorm = torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
  • input 输入张量
  • p 是范数计算中的幂指数值,p = 2 时即为 2 范数
  • dim 指定计算的维度,如果 dim 是整数值,则计算向量范数。当输入张量 input 超过2维,将在最后一维计算向量范数
  • keepdim 指明是否保留输出张量的维度dim
  • out 输出张量
  • dtype 返回张量的期待数据类型

令特征向量除以向量的L2范数,expand_as 函数将范数 fnorm 扩展成张量 ff 相同的维度。

 ff = ff.div(fnorm.expand_as(ff))

然后使用 tensor.div 完成除法。

Tensor.div(value, *, rounding_mode=None)

最后,使用 torch.cat 在第 0 维上拼接输入张量

features = torch.cat((features,ff.data.cpu()), 0)

4. 生成 Matlab 文件

通过上述步骤实现了 query 和 gallery 图片特征的提取,将特征矩阵存储到 pytorch_result.mat 文件中。

# Save to Matlab for check
result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam}
scipy.io.savemat('pytorch_result.mat',result)

为了评估模型效果,还要记录图片的 label 和 camera 。
这里使用 get_id 函数通过图片名称获取 label 和 camera 信息。

def get_id(img_path):
    camera_id = []
    labels = []
    for path, v in img_path:
        #filename = path.split('/')[-1]
        filename = os.path.basename(path)
        label = filename[0:4]
        camera = filename.split('c')[1]
        if label[0:2]=='-1':
            labels.append(-1)
        else:
            labels.append(int(label))
        camera_id.append(int(camera[0]))
    return camera_id, labels

gallery_path = image_datasets['gallery'].imgs
query_path = image_datasets['query'].imgs

gallery_cam,gallery_label = get_id(gallery_path)
query_cam,query_label = get_id(query_path)

生成的 Matlab 文件将被脚本 evaluate_gpu.py 使用,用于计算模型的评估指标。

参考链接

  1. pytorch求范数函数——torch.norm
  2. pytorch torch.norm 文档
  3. Pytorch expand_as()函数
  4. torch.cat()函数的官方解释,详解以及例子
  5. torch.stack()的官方解释,详解以及例子
上一篇:POI导出EXCEL经典实现


下一篇:Java POI 导出EXCEL经典实现 Java导出Excel