类激活热力图实例代码

使用mobilenetv3的类激活热力图(CAM,class activation map)

本人是个菜鸟,只能看看示例代码稍微改改,各位大神有好点子请在评论区留言
类激活热力图可以看出那些部分对分类的结果影响较大,一般使用卷积神经网络的最后一层网络的输出和其对应的梯度。需要新建一个网络并且求取梯度。
话不多说,上代码
代码中需要修改的就一个图片路径,其余可不动,有能力则自行修改

import torch.nn as nn
from nets.mobilenet_v3 import mobilenet_v3
from PIL import Image
import numpy as np
from torchvision import transforms
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
import torch
from torchvision import models

class MyMb(nn.Module):
    def __init__(self):
        super(MyMb, self).__init__()
        # 提取模型,不需要自己找权重,会自动下载
        self.model = models.mobilenet_v3_large(pretrained=True)
        self.features_conv = self.model.features[:16]  # 使用到16层
        self.max_pool = nn.MaxPool2d(2, stride=2)  # 这里是一些处理
        self.avgpool = self.model.avgpool  # 这也是
        self.classifier = nn.Sequential(  # 这个分类器中的参数可自定义
            nn.Linear(160*1*1, 80),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(80, 1000)
        )
        self.gradients = None  # 用于生成梯度占位符

    def activations_hook(self, grad):
        # 获取梯度的钩子
        self.gradients = grad

    def forward(self, x):
        x = self.features_conv(x)
        h = x.register_hook(self.activations_hook)
        # 对主干网络的结果进行池化操作,可以注释掉
        x = self.max_pool(x)
        x = self.avgpool(x)
        # print(x.shape)
        x = x.view(1, -1)  # resize
        x = self.classifier(x)
        return x
	# 获取梯度
    def get_activations_gradient(self):
        return self.gradients
	# 获取主干网络的输出
    def get_activations(self, x):
        return self.features_conv(x)

# 此部分根据自己的图片路径修改
img = Image.open("img/street.jpg")
# 初始化网络
mymobile = MyMb()
mymobile.eval()
# 输入图片的预处理
data_transforms = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
input_im = data_transforms(img).unsqueeze(0)
# print(input_im)
# 输入到网络进行预测
im_pre = mymobile(input_im)
softmax = nn.Softmax(dim=1)
im_pre_prob = softmax(im_pre)
# 筛选出前5个得分高的
prob, prelab = torch.topk(im_pre_prob, 5)
prob = prob.data.numpy().flatten()
prelab = prelab.numpy().flatten()

# 下面计算所需的特征映射和梯度信息
im_pre[:, prelab[0]].backward()  # 获取相对于模型参数的梯度
gradients = mymobile.get_activations_gradient()  # 获取模型的梯度
mean_gradients = torch.mean(gradients, dim=[0, 2, 3])  # 计算梯度相应通道的均值
activations = mymobile.get_activations(input_im).detach()  # 获取输出的卷积特征
for i in range(len(mean_gradients)):
    # 每个通道乘以相应的均值
    activations[:,i,:,:] *= mean_gradients[i]
heatmap = torch.mean(activations, dim=1).squeeze()  # 计算所有通道的均值输出得到热力图
# 标准化操作
heatmap = F.relu(heatmap)
heatmap /= torch.max(heatmap)
heatmap = heatmap.numpy()
plt.matshow(heatmap)  # 可视化热力图

# 将原图像与热力图像融合
img = cv2.imread("img/street.jpg")
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
heatmap = np.uint8(255*heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
Grad_cam_img = heatmap * 0.4 +img
Grad_cam_img = Grad_cam_img / Grad_cam_img.max()

# 可视化
b, g, r = cv2.split(Grad_cam_img)
Grad_cam_img = cv2.merge([r, g, b])
plt.figure()
plt.imshow(Grad_cam_img)
plt.show()

本文使用的是现成的主干网络,大家可以开发一下,用于自己搭建的主干网络。
虽然说这个代码可以正确生成热力图,不过我不知道为什么每次生成的都不大一样,但是有些图是类似的。
最后生成的图片
类激活热力图实例代码

类激活热力图实例代码

上一篇:基于mmdetection的热力图绘制


下一篇:用关键点进行目标检测