PyTorch: Gradient Reversal

With the help of Function from torch.autograd, we can define a layer whose forward pass is the identity and whose backward pass negates the gradient:

import torch
from torch.autograd import Function
import torch.nn as nn


class ReverseLayer(Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # The forward pass leaves the input untouched.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Negate the incoming gradient before it propagates to earlier layers.
        return grad_output.neg()


class Net(nn.Module):
    """Baseline: three chained 10x10 parameter matrices, no gradient reversal."""

    def __init__(self):
        super().__init__()
        self.parameter1 = nn.Parameter(torch.ones(10, 10))
        self.parameter2 = nn.Parameter(torch.ones(10, 10))
        self.parameter3 = nn.Parameter(torch.ones(10, 10))

    def forward(self, x):
        return x @ self.parameter1 @ self.parameter2 @ self.parameter3


class ReverseNet(nn.Module):
    """Same as Net, but with ReverseLayer applied after parameter2, so the
    gradients flowing back to parameter1 and parameter2 are negated."""

    def __init__(self):
        super().__init__()
        self.parameter1 = nn.Parameter(torch.ones(10, 10))
        self.parameter2 = nn.Parameter(torch.ones(10, 10))
        self.parameter3 = nn.Parameter(torch.ones(10, 10))

    def forward(self, x):
        x1 = x @ self.parameter1
        # The gradient reversal sits between parameter2 and parameter3.
        x2 = ReverseLayer.apply(x1 @ self.parameter2)
        return x2 @ self.parameter3


# Both nets are initialised with identical parameters (all ones), so their
# forward outputs -- and therefore the two losses -- are equal.
dataInput = torch.randn(2, 10)
dataTarget = torch.randn(2, 10)

net1 = Net()
net2 = ReverseNet()
loss1 = torch.mean(net1(dataInput) - dataTarget)
loss1.backward()
loss2 = torch.mean(net2(dataInput) - dataTarget)
loss2.backward()
print('=======================PARAMETER1============================')
print(net1.parameter1.grad[0])
print(net2.parameter1.grad[0])
print('=======================PARAMETER2============================')
print(net1.parameter2.grad[0])
print(net2.parameter2.grad[0])
print('=======================PARAMETER3============================')
print(net1.parameter3.grad[0])
print(net2.parameter3.grad[0])

'''
Comparing the printed gradients shows that, because of the chain rule,
the gradients of every parameter placed before the reverse layer
(parameter1 and parameter2) are negated, while the gradient of
parameter3, which comes after the reverse layer, is unchanged.
'''
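A quick numerical check (a minimal sketch reusing the names from the script above) makes the same point without reading the printouts:

# Parameters before the reverse layer: gradients are exact negations.
assert torch.allclose(net1.parameter1.grad, -net2.parameter1.grad)
assert torch.allclose(net1.parameter2.grad, -net2.parameter2.grad)
# Parameter after the reverse layer: gradients are identical.
assert torch.allclose(net1.parameter3.grad, net2.parameter3.grad)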