KL散度(转载自微信公众号机器之心)
首先让我们确立一些基本规则。我们将会定义一些我们需要了解的概念。
分布(distribution)
分布可能指代不同的东西,比如数据分布或概率分布。我们这里所涉及的是概率分布。假设你在一张纸上画了两根轴(即 X 和 Y),我可以将一个分布想成是落在这两根轴之间的一条线。其中 X 表示你有兴趣获取概率的不同值。Y 表示观察 X 轴上的值时所得到的概率。即 y=p(x)。下图即是某个分布的可视化。
这是一个连续概率分布。比如,我们可以将 X 轴看作是人的身高,Y 轴是找到对应身高的人的概率。
如果你想得到离散的概率分布,你可以将这条线分成固定长度的片段并以某种方式将这些片段水平化。然后就能根据这条线的每个片段创建边缘互相连接的矩形。这就能得到一个离散概率分布。
事件(event)
对于离散概率分布而言,事件是指观察到 X 取某个值(比如 X=1)的情况。我们将事件 X=1 的概率记为 P(X=1)。在连续空间中,你可以将其看作是一个取值范围(比如 0.95<X<1.05)。注意,事件的定义并不局限于在 X 轴上取值。但是我们后面只会考虑这种情况。
假设我们是一组正在广袤无垠的太空中进行研究的科学家。我们发现了一些太空蠕虫,这些太空蠕虫的牙齿数量各不相同。现在我们需要将这些信息发回地球。但从太空向地球发送信息的成本很高,所以我们需要用尽量少的数据表达这些信息。我们有个好方法:我们不发送单个数值,而是绘制一张图表,其中 X 轴表示所观察到的不同牙齿数量(0,1,2…),Y 轴是看到的太空蠕虫具有 x 颗牙齿的概率(即具有 x 颗牙齿的蠕虫数量/蠕虫总数量)。这样,我们就将观察结果转换成了分布。
发送分布比发送每只蠕虫的信息更高效。但我们还能进一步压缩数据大小。我们可以用一个已知的分布来表示这个分布(比如均匀分布、二项分布、正态分布)。举个例子,假如我们用均匀分布来表示真实分布,我们只需要发送两段数据就能恢复真实数据;均匀概率和蠕虫数量。但我们怎样才能知道哪种分布能更好地解释真实分布呢?这就是 KL 散度的用武之地。
直观解释:KL 散度是一种衡量两个分布(比如两条线)之间的匹配程度的方法。
让我们对示例进行一点修改,为了能够检查数值的正确性,让我们将概率值修改成对人类更友好的值(相比于上述博文中的值)。我们进行如下假设:假设有 100 只蠕虫,各种牙齿数的蠕虫的数量统计结果如下。
0 颗牙齿:2(概率:p_0 = 0.02)
1 颗牙齿:3(概率:p_1 = 0.03)
2 颗牙齿:5(概率:p_2 = 0.05)
3 颗牙齿:14(概率:p_3 = 0.14
4 颗牙齿:16(概率:p_4 = 0.16)
5 颗牙齿:15(概率:p_5 = 0.15)
6 颗牙齿:12(概率:p_6 = 0.12)
7 颗牙齿:8(概率:p_7 = 0.08)
8 颗牙齿:10(概率:p_8 = 0.1)
9 颗牙齿:8(概率:p_9 = 0.08)
10 颗牙齿:7(概率:p_10 = 0.07)
快速做一次完整性检查!确保蠕虫总数为 100,且概率总和为 1.0.
蠕虫总数 = 2+3+5+14+16+15+12+8+10+8+7 = 100
概率总和 = 0.02+0.03+0.05+0.14+0.16+0.15+0.12+0.08+0.1+0.08+0.07 = 1.0
可视化结果为:
尝试 1:使用均匀分布建模
我们首先使用均匀分布来建模该分布。均匀分布只有一个参数:均匀概率;即给定事件发生的概率。
均匀分布和我们的真实分布对比:
先不讨论这个结果,我们再用另一种分布来建模真实分布。
尝试 2:使用二项分布建模
你可能计算过抛硬币正面或背面向上的概率,这就是一种二项分布概率。我们可以将同样的概念延展到我们的问题上。对于有两个可能输出的硬币,我们假设硬币正面向上的概率为 p,并且进行了 n 次尝试,那么其中成功 k 次的概率为:
二项分布的均值和方差
我们还可以定义二项分布的均值和方差,如下:
均值= np
方差= np(1-p)
均值是什么意思?均值是指你进行 n 次尝试时的期望(平均)成功次数。如果每次尝试成功的概率为 p,那么可以说 n 次尝试的成功次数为 np。
方差又是什么意思?它表示真实的成功尝试次数偏离均值的程度。为了理解方差,让我们假设 n=1,那么等式就成了「方差= p(1-p)」。那么当 p=0.5 时(正面和背面向上的概率一样),方差最大;当 p=1 或 p=0 时(只能得到正面或背面中的一种),方差最小。
现在我们已经理解了二项分布,接下来回到我们之前的问题。首先让我们计算蠕虫的牙齿的期望数量:
有了均值,我们可以计算 p 的值:
均值 = np
5.44 = 10p
p = 0.544
注意,这里的 n 是指在蠕虫中观察到的最大牙齿数。你可能会问我们为什么不把蠕虫总数(即 100)或总事件数(即 11)设为 n。我们很快就将看到原因。有了这些数据,我们可以按如下方式定义任意牙齿数的概率。
鉴于牙齿数的取值最大为 10,那么看见 k 颗牙齿的概率是多少(这里看见一颗牙齿即为一次成功尝试)?
从抛硬币的角度看,这就类似于:
假设我抛 10 次硬币,观察到 k 次正面向上的概率是多少?
从形式上讲,我们可以计算所有不同 k 值的概率图片。其中 k 是我们希望观察到的牙齿数量。图片是第 k 个牙齿数量位置(即 0 颗牙齿、1 颗牙齿……)的二项概率。所以,计算结果如下:
我们的真实分布和二项分布的比较如下:
现在回头看看我们已经完成的工作。首先,我们理解了我们想要解决的问题。我们的问题是将特定类型的太空蠕虫的牙齿数据统计用尽量小的数据量发回地球。为此,我们想到用某个已知分布来表示真实的蠕虫统计数据,这样我们就可以只发送该分布的参数,而无需发送真实统计数据。我们检查了两种类型的分布,得到了以下结果。
均匀分布——概率为 0.0909
二项分布——n=10、p=0.544,k 取值在 0 到 10 之间。
让我们在同一个地方可视化这三个分布:
我们如何定量地确定哪个分布更好?
经过这些计算之后,我们需要一种衡量每个近似分布与真实分布之间匹配程度的方法。这很重要,这样当我们发送信息时,我们才无需担忧「我是否选择对了?」毕竟太空蠕虫关乎我们每个人的生命。
这就是 KL 散度的用武之地。KL 散度在形式上定义如下:
其中 q(x) 是近似分布,p(x) 是我们想要用 q(x) 匹配的真实分布。直观地说,这衡量的是给定任意分布偏离真实分布的程度。如果两个分布完全匹配,那么图片,否则它的取值应该是在 0 到无穷大(inf)之间。KL 散度越小,真实分布与近似分布之间的匹配就越好。
计算 KL 散度
我们计算一下上面两个近似分布与真实分布之间的 KL 散度。首先来看均匀分布:
再看看二项分布:
神经网络中的KL散度
Dkl(p||q) :模型更倾向选择一个分布q,使得它在p具有高概率的地方具有高概率。例如当p具有多个峰时,q选择将这些峰模糊在一起,以便将高概率质量放到所有峰上。
Dkl(q||p):模型倾向选择一个分布q,使得它在p具有低概率的地方具有低概率。当p具有多个峰且这些峰间隔很宽时, q会选择单个峰,以避免将概率质量放在p的多个峰之间的低概率区域。另外,如果这些峰没有别足够强的低概率区域分离,那么也可以选择这个公式来强调高概率的地方具有高概率的结果,此时这个方向的KL散度仍然可能选择模糊这些峰。
KL散度在神经网络中的用法
因为神经网络一般不会预测出分布的概率密度函数,所以以上直接基于概率密度函数的公式没法用,但是可以用经过变换之后的公式来计算两个分布的KL散度。
一般神经网络都是通过预测出分布的均值和方差,再计算与真实分布的KL距离。
1、假设 pnxm是服从原分布的n个样本,每个样本有m个维度。 μp,Σp为真实分布的均值和协方差矩阵, μq,Σq为网络预测的均值方差,那么真实分布和网络预测的分布KL散度为:
2、当真实分布p服从一元标准正态分布,即 p∼N(0, 1), 网络预测的均值和方差为q∼N(μ, σ2),
StackGAN++中的CA的理解
首先,文本描述 t首先由编码器编码,产生文本embedding 。
文本嵌入是非线性的转化为潜在条件变量作为生成器的输入。但是,文字的潜在空间的embedding通常是高维的(>100维),在数据量有限的情况下通常会导致潜在的数据流形的不连续,这是不可取的生成模型学习方法。为了缓解这个问题,作者引入条件增强技术来产生额外的条件变量 ,这边的条件变量不是固定的而是作者从独立的高斯分布随机地采样的。其中平均值μ(φ(t))和对角协方差矩阵Σ(φ(t))是文本embedding的函数。
条件增强在少量图像-文本对的情况下产生更多的训练数据,并有助于对条件流形的小扰动具有鲁棒性。为了进一步强化条件流形的平滑性并避免过拟合,作者在训练过程中将以下正则化项添加到生成器的目标中:
这是标准高斯分布和条件高斯分布之间的KL散度。 在条件增强中引入的随机性有利于对文本进行图像翻译建模,因为相同的句子通常对应于具有各种姿势和外观的对象。
CA代码理解
class GLU(nn.Module): # 第二维大小除以2。
def __init__(self):
super(GLU, self).__init__()
def forward(self, x):
nc = x.size(1)
assert nc % 2 == 0, 'channels dont divide 2!'
nc = int(nc/2)
return x[:, :nc] * F.sigmoid(x[:, nc:])
class CA_NET(nn.Module):
# some code is modified from vae examples
# (https://github.com/pytorch/examples/blob/master/vae/main.py
def __init__(self):
super(CA_NET, self).__init__()
self.t_dim = cfg.TEXT.EMBEDDING_DIM
self.c_dim = cfg.GAN.CONDITION_DIM
self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
self.relu = GLU()
def encode(self, text_embedding):
# 推理网络
# 这里以观测值作为输入,并输出用于潜在表示的条件分布的一组参数。我们将此分布建模为对角高斯模型。在这种情况下,将输出因式分解为高斯均值和对数方差参数(为了数值稳定性使用对数方差而不是直接使用方差)。
x = self.relu(self.fc(text_embedding))
mu = x[:, :self.c_dim]
logvar = x[:, self.c_dim:]
return mu, logvar # 均值, 方差的log。
def reparametrize(self, mu, logvar):
# 重参数化技巧
# 在优化过程中,首先从单位高斯采样,然后乘以标准偏差并加平均值。这样可以确保梯度能够通过样本传递到推理网络参数。
std = logvar.mul(0.5).exp_()
# 高斯分布的噪声
if cfg.CUDA:
eps = torch.cuda.FloatTensor(std.size()).normal_()
else:
eps = torch.FloatTensor(std.size()).normal_()
eps = Variable(eps)
return eps.mul(std).add_(mu)
def forward(self, text_embedding):
mu, logvar = self.encode(text_embedding)
c_code = self.reparametrize(mu, logvar)
return c_code, mu, logvar