【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法


作者:CHEONG

公众号:AI机器学习与知识图谱

研究方向:自然语言处理与知识图谱

阅读本文之前,首先注意以下两点:

1. 机器学习系列文章常含有大量公式推导证明,为了更好理解,文章在最开始会给出本文的重要结论,方便最快速度理解本文核心。需要进一步了解推导细节可继续往后看。

2. 文中含有大量公式,若读者需要获取含公式原稿Word文档,可关注公众号【AI机器学习与知识图谱】后回复:变分推断第二讲,可添加微信号【17865190919】进学习交流群,加好友时备注来自CSDN。原创不易,转载请告知并注明出处!

本文将先对变分推断所要解决的问题进行分析,然后给出基于Mean Field的变分推断解法。


一、本文结论

结论1: 变分推断的主要思想:在给定数据集 X X X下,问题是求后验概率 p p p,简单情况下后验概率 p p p可直接通过贝叶斯公式推导求出,但有些情况无法直接求解。因此变分推断想法是先假设另一个简单的概率分布 q q q,如高斯分布,通过优化 p p p和 q q q之间距离最小化,让概率分布 q q q逼近 p p p,这样就可以用概率分布 q q q近似表示后验概率 p p p。

结论2: 基于Mean Field的变分推断方法主要是假设将隐变量 z z z分成M个相互独立的部分 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1​,z2​,...,zM​) ,当求 q j ( z j ) q_j(z_j) qj​(zj​)时固定剩下M-1个部分。

结论3: 基于Mean Field的变分推断方法存在的两个问题:(1)假设将 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1​,z2​,...,zM​)分成M个相互独立的部分,然后固定其他依次求得 q j ( z j ) q_j(z_j) qj​(zj​)。这个假设太强烈,在一些问题是无法分成相互独立的各个部分;(2)最后求出来的 q j ( z j ) q_j(z_j) qj​(zj​)仍然需要进行求积分,在一些问题中,仍然可能是Intractable,无法求解的。


二、问题分析

观测数据Observed Data: X X X

隐变量Latent Variable: Z Z Z

完整数据Complete Data: ( X , Z ) (X, Z) (X,Z)

目的: 求数据的后验概率 p ( z ∣ x ) p(z|x) p(z∣x),下面先给出变分推断的分析思路

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

首先由简单的联合概率分布的分解式引出问题,如下公式所示:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

通过两边加log变形为:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

为了近似求解后验概率 p ( z ∣ x ) p(z|x) p(z∣x),我们需要先引入另一个分布 q ( z ) q(z) q(z),整合进上面公式中:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

接下来分别将上式的左边和右边部分对 q ( z ) q(z) q(z)进行积分:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

其中

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

所以左边在积分后仍然是 l o g p ( x ) logp(x) logp(x),接下来对右边部分进行积分:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

其中前半部分是Evidence Lower Bound,简称为 E L B O ELBO ELBO:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

后半部分是概率分布 p p p和 q q q的相对熵:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

因此有:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

因为当数据给定的情况下,左边 l o g p ( x ) logp(x) logp(x)是定值,即 E L B O + K L ( q ∣ ∣ p ) ELBO+KL(q||p) ELBO+KL(q∣∣p)是一个定值,而其中 K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)是大于等于0的,且 K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)越小代表概率分布 p p p和 q q q就越接近,也就是我们要优化的目标,但 K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)中包含后验概率不好直接优化最小,但因为 E L B O + K L ( q ∣ ∣ p ) ELBO+KL(q||p) ELBO+KL(q∣∣p)是定值,所以我们可以优化让 E L B O ELBO ELBO部分最大, K L ( q ∣ ∣ p ) KL(q||p) KL(q∣∣p)相对就越小,这样便可以用概率分布 q q q来代替 p p p了。


三、公式推导

通过上一小节的描述已经明确了变分推断需要优化的目标,总结为如下公式:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

下面通过公式推导求解是的 E L B O ELBO ELBO最大的后验概率 q ( z ) q(z) q(z)的值,使用基于Mean Field的变分推断的解法求解后验概率分布 p ( z ∣ x ) p(z|x) p(z∣x)

先假设 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1​,z2​,...,zM​),并且这M份之间是相互独立的,则有:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

接下来对 E L B O ELBO ELBO项进行展开,并将 q ( z ) q(z) q(z)的值代入:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

下面为了简便,先做一下变量假设:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

在推导 A A A和 B B B前,先固定 z = ( z 1 , . . . , z j − 1 , z j + 1 . . . , z M ) z=(z_1,...,z_{j-1}, z_{j+1}...,z_M) z=(z1​,...,zj−1​,zj+1​...,zM​),先 z j z_j zj​,接下来先推导 A A A

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

其中有:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

因此可以得出 A A A的值如下:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

接下来推导 B B B:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

其中有:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

因此得出了 B B B的值:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

因为固定了 z = ( z 1 , . . . , z j − 1 , z j + 1 . . . , z M ) z=(z_1,...,z_{j-1}, z_{j+1}...,z_M) z=(z1​,...,zj−1​,zj+1​...,zM​),只求未知量 z j z_j zj​,所以:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

其中 C C C是常量,至此有:

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

因此当KL取0时, E L B O ELBO ELBO能达到最大值,所以这里求出 q j ( z j ) q_j(z_j) qj​(zj​):

【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法

其他的 q 1 ( z 1 ) , q 2 ( z 2 ) , , . . . , q M ( z M ) q_1(z_1),q_2(z_2),,...,q_M(z_M) q1​(z1​),q2​(z2​),,...,qM​(zM​)求解方法相同。这样求出了 q ∗ ( z ) q^{*}(z) q∗(z)求等价于求出了后验概率 p ( z ∣ x ) p(z|x) p(z∣x)。


正如文章开头结论所说,基于Mean Field的变分推断方法存在的两个问题,下一节变分推断将介绍另一种解法:基于随机梯度上升SGD的变分推断推导方案:

1、假设将 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1​,z2​,...,zM​) 分成M个相互独立的部分,然后固定其他依次求得 q j ( z j ) q_j(z_j) qj​(zj​)。这个假设太强烈,在一些问题是无法分成相互独立的各个部分;

2、最后求出来的 q j ( z j ) q_j(z_j) qj​(zj​)仍然是求积分,在一些问题中,仍然可能是Intractable,无法求解的。

上一篇:我是如何一步步编码完成万仓网的


下一篇:【模板】树的直径