1. Hook Function Concepts
1.1 Why Hooks Were Introduced
PyTorch builds a dynamic computation graph. Once the graph has finished executing, intermediate variables (such as feature maps and the gradients of non-leaf nodes) are released. We often still need these intermediate variables, and hook functions let us attach extra functions to the main computation to retrieve or even modify them.
1.2 The Hook Function Mechanism
Hook mechanism: implement extra functionality without modifying the main body (the forward and backward passes), like a pendant hanging from a hook, hence the name "hook".
The `__call__()` method of `nn.Module` is itself built on this hook mechanism. The whole of `__call__()` breaks down into four parts:
- forward_pre_hook
- forward
- forward_hook
- backward_hook
As the list above shows, `__call__()` first runs any forward_pre_hook functions, then executes the forward pass, then runs any forward_hook functions, and finally sets up the backward_hook functions that fire during the backward pass.
So the forward pass does not merely execute forward propagation; it also exposes hook interfaces through which extra operations and features can be implemented, as the sketch below illustrates.
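A heavily simplified sketch of this flow, written in the spirit of PyTorch's source rather than as a verbatim copy of it (the `_forward_pre_hooks`, `_forward_hooks`, and `_backward_hooks` dictionaries are real internal attributes of `nn.Module`):

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():      # 1. forward_pre_hook
        result = hook(self, input)
        if result is not None:
            input = result if isinstance(result, tuple) else (result,)
    result = self.forward(*input, **kwargs)            # 2. forward
    for hook in self._forward_hooks.values():          # 3. forward_hook
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if self._backward_hooks:                           # 4. backward_hook
        pass  # these hooks are attached to the output's grad_fn and fire during backward()
    return result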
1.3 The Four Hook Functions
They fall into three categories: hooks on tensors, hooks on the forward pass, and hooks on the backward pass; all four are listed below, followed by a sketch of the registration pattern they share:
- torch.Tensor.register_hook(hook)
- torch.nn.Module.register_forward_hook
- torch.nn.Module.register_forward_pre_hook
- torch.nn.Module.register_backward_hook
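Each of these `register_*` methods returns a `torch.utils.hooks.RemovableHandle`. A minimal sketch of the shared register/remove lifecycle, using the tensor hook as the example:

import torch

t = torch.ones(1, requires_grad=True)
handle = t.register_hook(print)  # every register_* call returns a removable handle
t.sum().backward()               # the hook fires here and prints tensor([1.])
handle.remove()                  # detach the hook once it is no longer needed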
2. Hook Functions and Feature Extraction
2.1 Tensor.register_hook
hook(grad)
Purpose: registers a backward hook on a tensor.
The hook takes a single argument, the tensor's gradient, and returns either a tensor or nothing.
Example: using hooks to capture and to modify the gradients of non-leaf nodes.
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed

# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()  # stores the tensor's gradient

    def grad_hook(grad):
        a_grad.append(grad)

    handle = a.register_hook(grad_hook)  # register the hook function on the tensor
    y.backward()

    # inspect the gradients; a, b, y are non-leaf nodes, so their .grad is None
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    handle.remove()
# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    def grad_hook(grad):  # a hook function that modifies the tensor's gradient
        grad *= 2
        return grad * 3   # a gradient returned from the hook overrides the original gradient

    handle = w.register_hook(grad_hook)
    y.backward()

    # inspect the gradient
    print("w.grad: ", w.grad)
    handle.remove()
Expected output: with w = 1 and x = 2 we have a = 3, b = 2, y = 6, so ∂y/∂w = a + b = 5 and ∂y/∂x = b = 2; a, b, and y are non-leaf nodes, so their .grad prints as None, while a_grad[0] holds ∂y/∂a = b = 2. In the second example the hook first doubles the gradient in place and then returns three times that, so w.grad prints tensor([30.]).
Notes:
- By appending to a list inside the hook, the tensor's gradient is saved and can still be read after the computation has finished.
- If the custom hook returns a gradient, the returned gradient replaces the original one.
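Beyond capturing gradients, a tensor hook can also transform them on the fly. A minimal sketch (an illustration, not part of the original notes) that clips a parameter's gradient during the backward pass:

import torch

w = torch.tensor([10.], requires_grad=True)
handle = w.register_hook(lambda grad: grad.clamp(-1., 1.))  # returned tensor replaces the gradient

loss = (3 * w).sum()  # d(loss)/dw = 3 before clipping
loss.backward()
print(w.grad)         # tensor([1.]) after the hook clamps it
handle.remove()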
2.2 Module.register_forward_hook
hook(module, input, output)
Purpose: registers a forward hook on a module.
Hook return value: none
Parameters:
- module: the current layer
- input: the data fed into the layer
- output: the data produced by the layer
Example: capturing the feature maps produced during the forward pass.
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed

# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # initialize the network: channel 0 weights all 1, channel 1 weights all 2, bias 0
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # register the hook
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    # net.conv1.register_forward_pre_hook(forward_pre_hook)
    # net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
    output = net(fake_img)

    # loss_fnc = nn.L1Loss()
    # target = torch.randn_like(output)
    # loss = loss_fnc(target, output)
    # loss.backward()

    # inspect: with an all-ones 4x4 input and a 3x3 kernel, conv channel 0 is all 9s, channel 1 all 18s
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
2.3 Module.register_forward_pre_hook
hook(module, input)
Purpose: registers a hook that runs before the module's forward pass (see the sketch after the parameter list).
Hook return value: none
Parameters:
- module: the current layer
- input: the data fed into the layer
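Section 2.4 below uses a pre-hook together with a backward hook; as a standalone illustration, here is a minimal sketch of a pre-hook that merely inspects a layer's input before forward runs (the names `layer` and `pre_hook` are chosen here, not from the original notes):

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def pre_hook(module, data_input):
    # data_input is the tuple of positional arguments about to be passed to forward()
    print("{} input shape: {}".format(module.__class__.__name__, data_input[0].shape))

handle = layer.register_forward_pre_hook(pre_hook)
layer(torch.randn(3, 4))  # prints: Linear input shape: torch.Size([3, 4])
handle.remove()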
2.4 Module.register_backward_hook
hook(module, grad_input, grad_output)
Purpose: registers a backward hook on a module.
Hook return value: Tensor or None
Parameters:
- module: the current layer
- grad_input: the gradients w.r.t. the layer's inputs
- grad_output: the gradients w.r.t. the layer's outputs
Example: using forward_pre_hook and backward_hook together.
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed

# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # initialize the network
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # this time, register all three hooks
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # inference, then backward so that the backward hook fires
    fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
    output = net(fake_img)

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()
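One caveat worth knowing: for modules composed of several operations, register_backward_hook can report incomplete grad_input/grad_output, and since PyTorch 1.8 the documentation recommends register_full_backward_hook instead, which takes a hook with the same signature:

# on PyTorch >= 1.8, the preferred equivalent of the registration above
net.conv1.register_full_backward_hook(backward_hook)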
2.5 Visualizing Feature Maps with Hook Functions
# -*- coding:utf-8 -*-
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools import set_seed
import torchvision.models as models

set_seed(1)  # set the random seed

# ----------------------------------- feature map visualization -----------------------------------
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # data
    path_img = "./lena.png"  # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
        img_tensor.unsqueeze_(0)  # chw --> bchw

    # model
    alexnet = models.alexnet(pretrained=True)

    # register a forward hook on every conv layer
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():  # named_modules() yields each submodule with its name
        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())
            n1, n2 = name.split(".")

            # hook_func reads its key from the module passed in at call time,
            # so redefining it on each loop iteration is safe
            def hook_func(m, i, o):
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0].detach()  # detach from the graph before visualization
        fmap.transpose_(0, 1)         # (B, C, H, W) --> (C, B, H, W), so each channel becomes one grid cell

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

    writer.close()  # flush the event file to disk
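Once the script has finished, the feature-map grids can be viewed by pointing TensorBoard at the writer's log directory, e.g. `tensorboard --logdir=./runs` (SummaryWriter writes under ./runs by default).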
3. CAM and Grad-CAM
3.1 CAM
CAM: Class Activation Map
Purpose: given the network's output, analyze which part of the image the network focused on to arrive at that output.
Basic idea: take a weighted sum over the channels of the network's final feature map to obtain an attention map.
Mechanism: apply global average pooling to the final feature maps so that each channel collapses into a single neuron; these neurons feed a fully connected output layer, and the FC weights leading to the neuron of the predicted class serve as the weights of the corresponding feature maps.
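In symbols: if $f_k(x, y)$ is the activation of channel $k$ of the final feature map at spatial location $(x, y)$ and $w_k^c$ is the FC weight connecting channel $k$ to class $c$, the class activation map is

$$M_c(x, y) = \sum_k w_k^c \, f_k(x, y)$$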
Drawback: the network has to be modified, because the final feature maps must pass through global average pooling before the classifier in order to obtain the weights; this usually means changing the later layers and retraining.
CAM: "Learning Deep Features for Discriminative Localization"
3.2 Grad-CAM
Grad-CAM: an improved version of CAM that uses gradients as the feature-map weights.
Mechanism: backpropagate from the output score of the class of interest to obtain the gradient of each feature map, i.e. one gradient value for every pixel of every feature map; average each gradient map to get that feature map's weight; take the weighted sum of the feature maps with these weights; finally pass the result through a ReLU to obtain the class activation map.
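In symbols: with $A^k$ the $k$-th feature map, $y^c$ the score for class $c$, and $Z$ the number of pixels per feature map,

$$\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k}, \qquad L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)$$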
Grad-CAM: "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
Experiment examples, with analysis and code: https://zhuanlan.zhihu.com/p/75894080
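To connect the two halves of these notes, here is a minimal Grad-CAM sketch built from the hooks introduced above. It is an illustration only: the choice of alexnet.features[10] (AlexNet's last conv layer) as the target layer and the random stand-in image are assumptions, not part of the original notes.

import torch
import torchvision.models as models

alexnet = models.alexnet(pretrained=True)
alexnet.eval()

fmaps, grads = [], []

def fwd_hook(module, inp, out):
    fmaps.append(out)                                 # save the feature maps A^k
    out.register_hook(lambda g: grads.append(g))      # tensor hook: save dy^c/dA^k during backward

alexnet.features[10].register_forward_hook(fwd_hook)  # last conv layer (assumed target layer)

img = torch.randn(1, 3, 224, 224)                     # stand-in for a preprocessed input image
score = alexnet(img)
score[0, score.argmax()].backward()                   # backprop from the top class score y^c

A, dA = fmaps[0], grads[0]                            # both of shape (1, 256, 13, 13)
alpha = dA.mean(dim=(2, 3), keepdim=True)             # alpha_k^c: average each gradient map
cam = torch.relu((alpha * A).sum(dim=1)).squeeze(0)   # ReLU(sum_k alpha_k^c A^k)
print(cam.shape)                                      # torch.Size([13, 13]); upsample to image size for overlay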