Using Function from torch.autograd
A gradient-reversal layer can be written as a custom autograd Function: the forward pass is the identity, and the backward pass negates the incoming gradient.
import torch
from torch.autograd import Function
import torch.nn as nn


class ReverseLayer(Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()
class Net(nn.Module):
    """Baseline: three chained linear maps, no gradient reversal."""

    def __init__(self):
        super().__init__()
        self.parameter1 = nn.Parameter(torch.ones(10, 10))
        self.parameter2 = nn.Parameter(torch.ones(10, 10))
        self.parameter3 = nn.Parameter(torch.ones(10, 10))

    def forward(self, x):
        return x @ self.parameter1 @ self.parameter2 @ self.parameter3
class ReverseNet(nn.Module):
    """Same computation as Net, but with a gradient-reversal layer inserted
    between parameter2 and parameter3."""

    def __init__(self):
        super().__init__()
        self.parameter1 = nn.Parameter(torch.ones(10, 10))
        self.parameter2 = nn.Parameter(torch.ones(10, 10))
        self.parameter3 = nn.Parameter(torch.ones(10, 10))

    def forward(self, x):
        x1 = x @ self.parameter1
        x2 = ReverseLayer.apply(x1 @ self.parameter2)  # reversal point
        return x2 @ self.parameter3
dataInput = torch.randn(2, 10)
dataTarget = torch.randn(2, 10)

net1 = Net()
net2 = ReverseNet()

# The forward passes are identical (the reverse layer is the identity),
# so the two losses are equal; only the backward passes differ.
loss1 = torch.mean(net1(dataInput) - dataTarget)
loss1.backward()
loss2 = torch.mean(net2(dataInput) - dataTarget)
loss2.backward()
print('=======================PARAMETER1============================')
print(net1.parameter1.grad[0])
print(net2.parameter1.grad[0])
print('=======================PARAMETER2============================')
print(net1.parameter2.grad[0])
print(net2.parameter2.grad[0])
print('=======================PARAMETER3============================')
print(net1.parameter3.grad[0])
print(net2.parameter3.grad[0])
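# Added sanity check (not in the original script): the gradients of the
# parameters upstream of the reverse layer should be exact negations of the
# baseline gradients, and the downstream gradient should be unchanged.
assert torch.allclose(net2.parameter1.grad, -net1.parameter1.grad)
assert torch.allclose(net2.parameter2.grad, -net1.parameter2.grad)
assert torch.allclose(net2.parameter3.grad, net1.parameter3.grad)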
'''
Because of the chain rule, the gradients of every parameter upstream of the
reverse layer (parameter1 and parameter2) are exactly negated, while the
gradient of parameter3, which comes after the reverse layer, is unchanged.
'''
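In practice this kind of layer appears most often in domain-adversarial training, where the reversed gradient is additionally scaled by a coefficient. The sketch below shows one common way to add that, assuming the coefficient is passed as a plain float; the name GradReverse and the alpha argument are illustrative choices, not part of the code above.

import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; scales and negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha      # stash the scaling factor for backward
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # backward must return one gradient per forward input; alpha is a
        # plain float, so its gradient slot is None.
        return grad_output.neg() * ctx.alpha, None


# Usage (hypothetical): features = GradReverse.apply(features, 0.5)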