参考的原文章
知乎上对应的翻译
初始化是为了防止梯度消失和爆炸
编写代码,假设输入是512的行向量,经过10个512x512的矩阵,计算输出的平均值和标准差。
输入行向量,每个矩阵,都是标准正态分布
import torch
# 随机生成一个512的输入值,服从正态分布
x = torch.randn(512)
y = x
for i in range(10):
a = torch.randn(512, 512)
y = a @ y
print(y.mean())
print(y.std())
运行结果:
改成100直接算不出来了
可以看到,平均值和标准差都非常大,因为假设是1x512的x和512x1的y做矩阵乘法,其中x和y都是标准正态分布,并且相互独立,则两个相乘是第二类修正贝塞尔函数
具体推导的知乎链接
因为数学太菜和编程功底不行,尝试了公式推导,python暴力算,以及蒙特卡洛枚举,都没成功。。。。太弱小了。
但是可以通过修改之前的代码,把每一步的矩阵的数绘制出直方图,便可以很直观地看到梯度爆炸。主要原因是方差变大了,虽然均值是0,但是分布会很广。
梯度爆炸的代码
import matplotlib.pyplot as plt
import torch
# 随机生成一个512的输入值,服从正态分布
x = torch.randn(512)
y = x
for i in range(1, 9):
a = torch.randn(512, 512)
y = a @ y
plt.subplot(2, 4, i)
plt.hist(y)
print(i, y.mean())
print(i, y.std())
plt.show()
梯度消失,即在输出的地方乘0.01,即代码如下:
import matplotlib.pyplot as plt
import torch
# 随机生成一个512的输入值,服从正态分布
x = torch.randn(512)*0.01
y = x
for i in range(1, 9):
a = torch.randn(512, 512)*0.01
y = a @ y
plt.subplot(2, 4, i)
plt.hist(y)
print(i, y.mean())
print(i, y.std())
plt.show()
运行结果如下:
正态分布的表达式为
y
=
1
σ
2
π
exp
−
x
2
2
σ
2
y=\frac{1}{\sigma\sqrt{2\pi}}\exp^{-\frac{x^2}{2\sigma^2}}
y=σ2π
1exp−2σ2x2
D
=
σ
2
D=\sigma^2
D=σ2
当
σ
1
=
0.1
σ
\sigma_1=0.1\sigma
σ1=0.1σ
y
1
=
1
σ
1
2
π
exp
−
x
2
2
σ
1
2
=
1
0.1
∗
σ
2
π
exp
−
x
2
2
σ
2
∗
0.01
\begin{aligned} y_1 &=\frac{1}{\sigma_1\sqrt{2\pi}}\exp^{-\frac{x^2}{2\sigma_1^2}}\\ &=\frac{1}{0.1*\sigma\sqrt{2\pi}}\exp^{-\frac{x^2}{2\sigma^2*0.01}} \end{aligned}
y1=σ12π
1exp−2σ12x2=0.1∗σ2π
1exp−2σ2∗0.01x2
D
=
0.01
σ
2
D=0.01\sigma^2
D=0.01σ2
而当
x
2
=
0.01
x
x_2 = 0.01x
x2=0.01x
即
100
x
2
=
x
100x_2=x
100x2=x
y
2
=
1
σ
2
π
exp
−
100
x
2
2
σ
2
=
1
σ
2
π
exp
−
x
2
2
σ
2
∗
0.01
\begin{aligned} y_2 &=\frac{1}{\sigma\sqrt{2\pi}}\exp^{-\frac{100x^2}{2\sigma^2}}\\ &=\frac{1}{\sigma\sqrt{2\pi}}\exp^{-\frac{x^2}{2\sigma^2*0.01}} \end{aligned}
y2=σ2π
1exp−2σ2100x2=σ2π
1exp−2σ2∗0.01x2
比较上式,有
10
y
1
=
y
2
10y_1=y_2
10y1=y2
题目除了放缩系数不同以外,别的都一样。因此,防止梯度爆炸,就要减少方差,让k的取值主要在1附近。减小方差的手段有两种,一种是对方差直接变化,乘0.1。第二种是对x进行变化,然后乘一个幅值。代码如下:
import math
import numpy as np
u = 0 # 均值μ
sig = math.sqrt(1) # 标准差δ
x = np.linspace(u - 3*sig, u + 3*sig, 50) # 定义域
y1 = np.exp(-(x - u) ** 2 / (2 * sig ** 2)) / (math.sqrt(2*math.pi)*sig) # 定义曲线函数
plt.plot(x, y1, "r", linewidth=2) # 加载曲线
sig = math.sqrt(1)
x2 = np.linspace(u - 3*sig, u + 3*sig, 50) * sig * 0.1
y2 = np.exp(-(x - u) ** 2 / (2 * sig ** 2)) / (math.sqrt(2*math.pi)*sig)
plt.plot(x2, y2, "g", linewidth=2) # 加载曲线
sig = 0.1 * sig
x = np.linspace(u - 3*sig, u + 3*sig, 50)
y3 = np.exp(-(x - u) ** 2 / (2 * sig ** 2)) / (math.sqrt(2*math.pi))
plt.plot(x, y3, "b", linewidth=2)
plt.grid(True) # 网格线
plt.show() # 显示
可以看到,绿线和蓝线完全重合,
而且这个例子里,输出的每一个y都是512个x累加,而每次取的x都是独立分布的,因此输出的y实际上均值是0,方差是512,即标准差是
512
\sqrt{512}
512
检验该结论的代码如下:
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
mean = 0.0
var = 0.0
for i in range(10000):
x = torch.randn(512)
a = torch.randn(512, 512)
y = a @ x
# 用item()方法将tensor转换为float
mean = mean + y.mean().item()
var = var + y.pow(2).mean().item()
mean = mean/10000
var = math.sqrt(var/10000)
print(mean)
print(var)
print(math.sqrt(512))
要使得经过矩阵的输出值仍然保持1的标准差,则要给每个输出的值的方差进行调整,即乘个数量开根号的缩放系数。代码如下:
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
mean = 0.0
var = 0.0
for i in range(10000):
x = torch.randn(512)
a = torch.randn(512, 512) * math.sqrt(1./512)
y = a @ x
# 用item()方法将tensor转换为float
mean = mean + y.mean().item()
var = var + y.pow(2).mean().item()
mean = mean/10000
var = math.sqrt(var/10000)
print(mean)
print(var)
print(math.sqrt(512))
运行结果如下:
Xavier初始化,在这篇论文里提到
传统的做法是将网络里的k全部都初始化为-1到1的均匀分布
对于在[a,b]间均匀分布的x,其方差为
(
b
−
a
)
2
12
=
4
12
\frac{(b-a)^2}{12}=\frac{4}{12}
12(b−a)2=124
因此对于输入为标准差为
512
\sqrt{512}
512
的正态分布,再乘上均匀分布,其方差为
512
4
12
=
13.06
\sqrt{512}\sqrt{\frac{4}{12}}=13.06
512
124
=13.06
和下面的python代码验证结果一致
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
mean = 0.0
var = 0.0
for i in range(10000):
x = torch.randn(512)
a = torch.Tensor(512, 512).uniform_(-1, 1)
y = a @ x
# 用item()方法将tensor转换为float
mean = mean + y.mean().item()
var = var + y.pow(2).mean().item()
mean = mean/10000
var = math.sqrt(var/10000)
print(mean)
print(var)
但是用传统方法,初始化使用均匀分布,而不是高斯分布,即放缩
n
\sqrt{n}
n
然后再乘均匀分布,在输出使用激活函数时,效果并不是很好。会出现梯度消失
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
x = torch.randn(512)
for i in range(100):
## Xavier
# a = torch.Tensor(512, 512).uniform_(-1, 1) * math.sqrt(6./512*2)
# traditional
a = torch.Tensor(512, 512).uniform_(-1, 1) * math.sqrt(1. / 512)
x = torch.tanh(x @ a)
# x = x @ a
print(x.mean())
print(x.std())
但是采用Xavier初始化,效果就要好得多。
它其实就是把缩放系数改成了下式
n
i
n_i
ni是该层网络的扇入,
n
i
+
1
n_{i+1}
ni+1是该层网络的扇出。效果不错
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
x = torch.randn(512)
def xavier(fan_in, fan_out):
return torch.Tensor(fan_in, fan_out).uniform_(-1, 1) * math.sqrt(6. / (fan_in + fan_out))
for i in range(100):
## Xavier
# a = torch.Tensor(512, 512).uniform_(-1, 1) * math.sqrt(6./512*2)
a = xavier(512, 512)
x = torch.tanh(x @ a)
# x = x @ a
print(x.mean())
print(x.std())
何凯明的初始化
之前用sigmoid和tanh这种激活函数时,我们希望每一层的输出平均值为0,标准差为1(为啥要为1?以后好好推推xaiver的原论文)
如果不放缩,就会梯度爆炸。测试如下,经过十层网络直接爆炸
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
def xavier(fan_in, fan_out):
return torch.Tensor(fan_in, fan_out).uniform_(-1, 1) * math.sqrt(6. / (fan_in + fan_out))
def relu(x):
# 钳位函数,小于0都设置成0
return x.clamp_min(0.)
x = torch.randn(512)
mean = 0
var = 0
for i in range(10):
a = torch.randn(512, 512)
x = relu(x @ a)
mean = mean + x.mean().item()
var = var + x.pow(2).mean().item()
print(mean/10)
print(math.sqrt(var/10))
而用何凯明的放缩系数,到了100层仍然活蹦乱跳
系数是
2
n
\sqrt{\frac{2}{n}}
n2
import matplotlib.pyplot as plt
import torch
import math
import numpy as np
def xavier(fan_in, fan_out):
return torch.Tensor(fan_in, fan_out).uniform_(-1, 1) * math.sqrt(6. / (fan_in + fan_out))
def relu(x):
# 钳位函数,小于0都设置成0
return x.clamp_min(0.)
x = torch.randn(512)
mean = 0
var = 0
for i in range(100):
a = torch.randn(512, 512) * math.sqrt(2/512)
x = relu(x @ a)
mean = mean + x.mean().item()
var = var + x.pow(2).mean().item()
print(mean/100)
print(math.sqrt(var/100))
所以还是有空回头去看看原论文推推公式。。。震撼到了