Pytorch提取AI模型的中间结果的方法

算法和数据的相互成全

 数据中,结构化数据的应用和管理相对简单,然而随着非结构化数据的大量涌现,其处理方式和传统的结构化数据有所不同。
 其一:处理工具智能化,智能化一方面体现再AI模型的应用,另外一方面也可以说数据具有了独特的情况,可随着模型的不同,数据就有所不同
 其二,随着模型的固化,其实也是一种智力方式的固化,不同的模型对数据的反应和推断会不同,在经过各种模型的打磨下,数据将做为一种全新的概念
     数据不仅仅是记录,也是改变。数据不仅仅是资源,也是一种生产资料。

数据应用算法的主要方面

增加对非结构化数据的了解
增加需要迭代和低算力模型的改进
对数据进行管理

学习方法

做数据处理相关工作的技能,在处理算法上有所不同    
  数据了解算法的学术类做法
      了解其基本原理
  	写简单的模型,不断的迭代模型
  
  数据了解算法的工程化阶段    
      1.利用算法提供的模型进行推断 infer
      2.写代码提取模型的部分结果
      3.训练并优化模型
      4.复现和实现模型

工程做法

 利用代码提取模型的部分结果
  目前主流上有三种做法
    第一种,是写一个模型,重写这个模型中的部分组件,这个组件执行相同的功能,却只返回自己需要的输出
	第二种,只执行模型的部分,按照通常做法创建模型,但使用自己的代码去替代原模型的forward()
	第三种,就是使用forward hooks

代码示例

import torch
import torch.cuda
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import cv2

def get_model():

    #model_path ="/home/test/soft/models/vgg16-397923af.pth"
    model_path ="/home/test/soft/models/resnet101-5d3b4d8f.pth"
    pre = torch.load(model_path)
    # 加载模型 
    #model_ft =  models.vgg16(pretrained=False)
    model_ft = models.resnet101(pretrained=False)
    model_ft.load_state_dict(pre)
    model_ft.cuda()
    return model_ft

    # # 查看模型结构
    # print(model_ft)
    # # 查看网络参数
    # for name, parameters in model_ft.named_parameters():
    #     print(name, ':', parameters.size())
    # # 网络模型的卷积方式以及权重数值
    # print("#############-parameters")
    # for child in model_ft.children():
    #     print(child) 
    #     # for param in child.parameters():
    #     #     print(param)

def deal_img(img_path):
    """Transforming images on GPU  单个图像导入"""
    image = cv2.imread(img_path) 
    image_new =  cv2.resize(image, (224,224))
    my_transforms= transforms.Compose(
        [ 
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229,0.224,0.225]) 
        ]
        )
    my_tensor = my_transforms(image_new)
    my_tensor = my_tensor.resize_(1,3,224,224)
    my_tensor= my_tensor.cuda()
    return my_tensor

def cls_inference(cls_model,imgpth):
    input_tensor = deal_img(imgpth)
    cls_model.eval()
    result = cls_model(input_tensor)
    result_npy = result.data.cpu().numpy()
    max_index = np.argmax(result_npy[0])
    return max_index

# 方式一
def feature_extract(cls_model,imgpth):
    cls_model.fc = torch.nn.LeakyReLU(0.1)
    cls_model.eval()
    input_tensor = deal_img(imgpth)
    result = cls_model(input_tensor)
    result_npy = result.data.cpu().numpy()
    return result_npy[0]

#方法二 
# 定义一个特征提取的类 resnet
#  提取特征层
class Feature_extractor(nn.Module):
    #"""从一个已经搭建好的网络中方便地提取到某些层的输出"""
    def __init__(self, submodule, extract_layers):
       super(Feature_extractor,self).__init__()
       self.submodule= submodule
       self.extract_layers = extract_layers
    #"针对该模型进行了修改-自定义了forward函数,选择在哪一层提取特征"
    #  forward函数针对Resnet模型进行了修改 
    def forward(self, input):
        outputs = []
        for name,module in self.submodule._modules.items():
            if name is "fc":
                input= input.view(input.size(0),-1)
            input = module(input)
            if name in self.extract_layers  and "fc" not in name:
                outputs.append(input)
        return outputs
		
def resnet_feature_extract(feature_model,imgpth):
    feature_model.eval()
    input_tensor = deal_img(imgpth)
    result = feature_model(input_tensor)
    return result[0]


#方法三 hook输出conv层的特征图


## 主函数
if __name__ == "__main__":
    image_path="/home/test/soft/8.jpg"
    # 建立模型
    model = get_model()
    cls_label = cls_inference(model,image_path)
    print(cls_label)
	#
    #feature = feature_extract(model,image_path)
    #print(feature)
    print(model)
	###方式二
    # layers you want to extract`
    print("\n######################################\n")
    target_layers = [ "layer1", "conv1"] 
    result_feature = Feature_extractor(model,target_layers)
    feature_ge = resnet_feature_extract(result_feature,image_path)
    print(feature_ge)

参考:

  Use Models  https://detectron2.readthedocs.io/en/latest/tutorials/models.html
上一篇:利用VS2008发布一个简单的webservice


下一篇:Pytest单元测试框架之parametrize参数化