18. Hook Functions and CAM Visualization

1. Hook Function Concepts

1.1 Why Hooks Are Needed

PyTorch runs on a dynamic computation graph. Once the computation finishes, intermediate variables (such as feature maps and the gradients of non-leaf nodes) are released. However, we often need access to these intermediate variables, and hook functions let us attach extra functions to the main computation, via the hook mechanism, to capture or modify them.

1.2 The Hook Mechanism

The hook mechanism: implement extra functionality without changing the main computation (the forward and backward passes), like an attachment hanging on a hook, hence the name "hook".

[Figure: the four stages of nn.Module.__call__()]
The __call__() function in nn.Module follows exactly this hook mechanism. The whole __call__ is divided into four parts:

  • forward_pre_hook
  • forward
  • forward_hook
  • backward_hook

As the figure above shows, __call__() first runs the forward_pre_hook functions, then performs the forward pass, then runs the forward_hook functions, while the backward_hook functions fire later during backpropagation.
So the forward pass is not merely the forward computation: it also exposes hook interfaces through which extra operations and functionality can be attached. A simplified sketch of this dispatch order is given below.
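
The following is an illustrative sketch of that dispatch order, not the actual PyTorch source (the real __call__ also handles keyword arguments, hook handles, and other bookkeeping):

import torch.nn as nn

def module_call(module: nn.Module, *inputs):
    """Simplified stand-in for nn.Module.__call__ -- illustrative only."""
    # 1. forward_pre_hooks may inspect or replace the input
    for hook in module._forward_pre_hooks.values():
        result = hook(module, inputs)
        if result is not None:
            inputs = result if isinstance(result, tuple) else (result,)

    # 2. the actual forward pass
    output = module.forward(*inputs)

    # 3. forward_hooks may inspect or replace the output
    for hook in module._forward_hooks.values():
        result = hook(module, inputs, output)
        if result is not None:
            output = result

    # 4. backward_hooks are not called here: they are attached to the
    #    autograd graph and fire later, during backward()
    return output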

1.3 The Four Hook Functions

They fall into three categories: hooks on tensors, hooks on the forward pass, and hooks on the backward pass:

  1. torch.Tensor.register_hook(hook)
  2. torch.nn.Module.register_forward_hook
  3. torch.nn.Module.register_forward_pre_hook
  4. torch.nn.Module.register_backward_hook

2. Hook Functions and Feature Extraction

2.1 Tensor.register_hook

hook(grad)

Purpose: registers a backward hook on a tensor.

The hook takes a single input argument, the tensor's gradient, and returns either a new tensor (which replaces the gradient) or nothing.

Example: using a hook to capture and modify the gradient of a non-leaf node

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed


# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()        # stores the tensor's gradient

    def grad_hook(grad):
        a_grad.append(grad)

    handle = a.register_hook(grad_hook)  # register the hook function on the tensor

    y.backward()

    # inspect the gradients
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    handle.remove()


# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):               # hook that modifies the tensor's gradient
        grad *= 2
        return grad*3                  # the gradient returned here overrides the original

    handle = w.register_hook(grad_hook)

    y.backward()

    # inspect the gradient
    print("w.grad: ", w.grad)
    handle.remove()

Output:

gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]:  tensor([2.])
w.grad:  tensor([30.])
Notes:

  • By appending to a predefined list inside the hook, the tensor's gradient is saved and can still be read after the backward pass has finished
  • If the custom hook returns a gradient, the returned value overrides the original gradient
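
When the goal is only to read (not modify) a non-leaf gradient, an alternative is to call a.retain_grad() before the backward pass, which makes autograd populate a.grad; a hook is needed when the gradient should also be changed.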

2.2 Module.register_forward_hook

hook(module, input, output)

Purpose: registers a forward hook on a module.

Return value: None

Parameters:

  • module: the current layer
  • input: the input to the current layer
  • output: the output of the current layer

Example: capturing feature maps during the forward pass

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed

# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # initialize the network
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # register hooks
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    # net.conv1.register_forward_pre_hook(forward_pre_hook)
    # net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    # loss_fnc = nn.L1Loss()
    # target = torch.randn_like(output)
    # loss = loss_fnc(target, output)
    # loss.backward()

    # inspect the results
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

Output: the pooled network output has shape (1, 2, 1, 1) with values 9 and 18 (an all-ones 4×4 input convolved with 3×3 kernels filled with 1 and 2, zero bias); the hook captures the conv feature map of shape (1, 2, 2, 2), and the stored input has shape (1, 1, 4, 4).

2.3 Module.register_forward_pre_hook

hook(module, input)

Purpose: registers a hook that runs before the module's forward pass.

Return value: None

Parameters:

  • module: the current layer
  • input: the input to the current layer
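
A minimal sketch of a less obvious use: a forward_pre_hook can replace the module's input by returning a new value (the toy nn.Linear layer and the scaling factor below are illustrative assumptions):

import torch
import torch.nn as nn

linear = nn.Linear(2, 1)

def scale_input(module, inp):
    # inp is the tuple of positional inputs; a returned tuple replaces it
    return (inp[0] * 2.0,)

handle = linear.register_forward_pre_hook(scale_input)
x = torch.ones(1, 2)
print(linear(x))       # the forward pass now sees x * 2
handle.remove()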

2.4 Module.register_backward_hook

hook(module, grad_input, grad_output)

Purpose: registers a backward hook on the module.

Return value: Tensor or None

Parameters:

  • module: the current layer
  • grad_input: the gradient with respect to the layer's input
  • grad_output: the gradient with respect to the layer's output

Example: using forward_pre_hook and backward_hook

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed


# ----------------------------------- 4 Module.register_forward_pre_hook and backward hook -----------------------------------
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # initialize the network
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # register hooks
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()
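
Note: in recent PyTorch releases, Module.register_backward_hook is deprecated in favor of Module.register_full_backward_hook, which reports grad_input and grad_output more reliably for modules composed of multiple operations; the code above follows the older API.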


2.5 Visualizing Feature Maps with Hook Functions

# -*- coding:utf-8 -*-
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools import set_seed
import torchvision.models as models

set_seed(1)  # set the random seed

# ----------------------------------- feature map visualization -----------------------------------
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # data
    path_img = "./lena.png"     # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)    # chw --> bchw

    # model
    alexnet = models.alexnet(pretrained=True)

    # register hooks
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():  # named_modules() yields every submodule together with its name

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())

            n1, n2 = name.split(".")

            def hook_func(m, i, o):
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0]
        fmap.transpose_(0, 1)

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

    writer.close()
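
After running the script, the feature-map grids can be viewed in TensorBoard, e.g. with tensorboard --logdir=./runs (SummaryWriter writes to ./runs by default).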




3. CAM and Grad-CAM

3.1 CAM

CAM: Class Activation Map

Purpose: given the network's output, analyze which part of the image the network attended to in order to produce that output.

Basic idea: take a weighted sum over the channels of the network's final feature map, which yields an attention-like heatmap.
How it works: apply global average pooling to the final feature maps, converting them into a vector with one neuron per channel, and feed that vector into a fully connected layer to produce the class scores. The FC weights connected to the predicted class's output neuron serve as the weights of the corresponding feature maps.
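
In formula form (notation from the CAM paper): with $f_k(x, y)$ the activation of feature map $k$ at spatial location $(x, y)$ and $w_k^c$ the FC weight connecting channel $k$ to class $c$, the class activation map for class $c$ is

$$M_c(x, y) = \sum_k w_k^c \, f_k(x, y)$$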

Drawback: the network has to be modified. The final feature maps must pass through global average pooling so the class weights can be read off, which often means replacing the later layers and retraining the model.

CAM paper: "Learning Deep Features for Discriminative Localization"

3.2 Grad-CAM

Grad-CAM: an improved version of CAM that uses gradients as the feature-map weights, so the network needs no modification or retraining
How it works: backpropagate from the output score of the target class to obtain the gradient of each feature map, i.e., a gradient value at every pixel of every feature map (one gradient map per feature map); average each gradient map to get one weight per feature map; form the weighted sum of the feature maps with these weights; finally pass the result through a ReLU activation to obtain the class activation map.
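
In formula form (notation from the Grad-CAM paper): with $A^k$ the $k$-th feature map, $y^c$ the score for class $c$, and $Z$ the number of spatial locations,

$$\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\Big( \sum_k \alpha_k^c A^k \Big)$$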

Grad-CAM paper: "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
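
Putting the pieces together, here is a minimal Grad-CAM sketch built from the hooks in Section 2. Assumptions: a pretrained torchvision ResNet-18 whose last conv block is layer4, and a random tensor standing in for a preprocessed image; this is a sketch of the recipe above, not the reference implementation linked below.

import torch
import torch.nn.functional as F
import torchvision.models as models

model = models.resnet18(pretrained=True).eval()

fmaps, grads = [], []

def fwd_hook(module, inp, out):
    fmaps.append(out)                                  # save feature maps A^k
    out.register_hook(lambda g: grads.append(g))       # save dy^c / dA^k

# hook the last conv stage (assumption: layer4 for ResNet-18)
model.layer4.register_forward_hook(fwd_hook)

img = torch.randn(1, 3, 224, 224)                      # stand-in for a preprocessed image
out = model(img)

score = out[0, out.argmax()]                           # score of the top-1 class
model.zero_grad()
score.backward()

A, dA = fmaps[0], grads[0]                             # both of shape (1, C, H, W)
weights = dA.mean(dim=(2, 3), keepdim=True)            # alpha_k: spatially averaged gradients
cam = F.relu((weights * A).sum(dim=1, keepdim=True))   # weighted sum + ReLU, (1, 1, H, W)
cam = F.interpolate(cam, size=(224, 224), mode='bilinear', align_corners=False)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]

The normalized cam is then typically overlaid on the input image as a heatmap.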

For an experiment example with analysis and code, see: https://zhuanlan.zhihu.com/p/75894080
