Pytorch下实现自定义激活函数

(每个人错误情况可能不一样,仅供参考。以复现一篇含自定义激活函数论文为例。)
1:拟设计激活函数:
Pytorch下实现自定义激活函数
2:代码编辑:

class log_act(nn.Module):
    def __init__(self, alpha, beta, positive_flag=True):
        super(log_act, self).__init__()
        self.positive_flag = positive_flag
        self.a = alpha
        self.C1 = beta
    def forward(self, x):
        x = x.detach().numpy()
        C0 = np.exp(-1)
        if self.positive_flag:
            alpha = self.a
        else:
            alpha = -self.a
        out = alpha * np.log(np.greater(x, 0)+C0) + self.C1
        out = torch.tensor(out)
        return out
        ```
#设置自定义输出值条件定义:
class log_act_helper(nn.Module):
    def __init__(self, alpha, beta):
        super(log_act_helper, self).__init__()
        self.a = alpha
        self.C1 = beta
    def forward(self, x):
        x0 = x.clone().detach()
        f1 = log_act(alpha=self.a, beta=self.C1, positive_flag=True)
        out1 = f1(x0)
        f2 = log_act(alpha=self.a, beta=self.C1, positive_flag=False)
        out2 = f2(x0)
        out = torch.where(np.greater(x0, 0), out1, out2)
        # out = torch.tensor(out)
        return out

激活函数设计完毕。
问题1:Pytorch下实现自定义激活函数
查找资料:带转换PyTorch张量带有梯度,直接转换为numpy数据会破坏梯度图。转换数据不需要保留梯度信息,x=x.detach().numpy() (不适合我的错误,尝试之后未解决)
**解决方案:**对卷积之后得到张量x进行克隆,传入激活函数计算:x=x.clone().detach()或者使用x=x.clone().detach().requires_grad_(True)其中clone()复制张量,梯度流仍流向原来的张量。detach()张量脱离计算图,不牵扯梯度计算。requires_grad_是否需要梯度。
问题2:
Pytorch下实现自定义激活函数
**解决方案:**问题是:日志中遇到了无效值。log函数图像如下:Pytorch下实现自定义激活函数

基于log(对数)函数log(A) 中当A<0时,对数函数不成立,必须取A>0,更改为:out = alpha * (np.log(np.greater(x, 0)+C0)) + 1 比较x与0大小,确保A取正值。

3:测试:

if __name__=="__main__":
    x = torch.rand(3, 3, 448, 448)
    conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                           bias=False)
    x = conv1(x)
    bn1 = nn.BatchNorm2d(64)
    x = bn1(x)
    activation = log_act_helper(alpha=0.2, beta=1)
    out = activation(x)
    print(out.shape)

输出:Pytorch下实现自定义激活函数
以上是激活函数不需要梯度情况。需要梯度情况,需要重定义forward和backward参考: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html添加链接描述

上一篇:oracle数据泵之解决方案(用户)导入导出。


下一篇:云锵投资 2021 年 12 月简报