作者:CHEONG

公众号:AI机器学习与知识图谱

研究方向:自然语言处理与知识图谱

阅读本文之前,首先注意以下两点:

1. 机器学习系列文章常含有大量公式推导证明,为了更好理解,文章在最开始会给出本文的重要结论,方便最快速度理解本文核心。需要进一步了解推导细节可继续往后看。

2. 文中含有大量公式,若读者需要获取含公式原稿Word文档,可关注公众号【AI机器学习与知识图谱】后回复:变分推断第二讲,可添加微信号【17865190919】进学习交流群,加好友时备注来自CSDN。原创不易,转载请告知并注明出处!

本文将先对变分推断所要解决的问题进行分析,然后给出基于Mean Field的变分推断解法。


一、本文结论

结论1: 变分推断的主要思想:在给定数据集XXX下,问题是求后验概率ppp,简单情况下后验概率ppp可直接通过贝叶斯公式推导求出,但有些情况无法直接求解。因此变分推断想法是先假设另一个简单的概率分布qqq,如高斯分布,通过优化pppqqq之间距离最小化,让概率分布qqq逼近ppp,这样就可以用概率分布qqq近似表示后验概率ppp

结论2: 基于Mean Field的变分推断方法主要是假设将隐变量zzz分成M个相互独立的部分z=(z1,z2,...,zM)z=(z_1,z_2,...,z_M)z=(z1,z2,...,zM) ,当求qj(zj)q_j(z_j)qj(zj)时固定剩下M-1个部分。

结论3: 基于Mean Field的变分推断方法存在的两个问题:(1)假设将z=(z1,z2,...,zM)z=(z_1,z_2,...,z_M)z=(z1,z2,...,zM)分成M个相互独立的部分,然后固定其他依次求得qj(zj)q_j(z_j)qj(zj)。这个假设太强烈,在一些问题是无法分成相互独立的各个部分;(2)最后求出来的qj(zj)q_j(z_j)qj(zj)仍然需要进行求积分,在一些问题中,仍然可能是Intractable,无法求解的。


二、问题分析

观测数据Observed Data:XXX

隐变量Latent Variable:ZZZ

完整数据Complete Data:(X,Z)(X, Z)(X,Z)

目的: 求数据的后验概率p(z∣x)p(z|x)p(zx),下面先给出变分推断的分析思路

在这里插入图片描述

首先由简单的联合概率分布的分解式引出问题,如下公式所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bnvN45Oj-1617958057057)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image023.png)]

通过两边加log变形为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wD6jQnPH-1617958057060)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image025.png)]

为了近似求解后验概率p(z∣x)p(z|x)p(zx),我们需要先引入另一个分布q(z)q(z)q(z),整合进上面公式中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OgrdLuv6-1617958057065)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image027.png)]

接下来分别将上式的左边和右边部分对q(z)q(z)q(z)进行积分:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-A47Hl0wR-1617958057071)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image031.png)]

其中

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-37r1SigK-1617958057077)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image033.png)]

所以左边在积分后仍然是logp(x)logp(x)logp(x),接下来对右边部分进行积分:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9F0pehv7-1617958057080)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image037.png)]

其中前半部分是Evidence Lower Bound,简称为ELBOELBOELBO

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Y9TMxn46-1617958057085)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image039.png)]

后半部分是概率分布pppqqq的相对熵:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-T4nemFit-1617958057087)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image045.png)]

因此有:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KbtHW55Z-1617958057091)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image047.png)]

因为当数据给定的情况下,左边logp(x)logp(x)logp(x)是定值,即ELBO+KL(q∣∣p)ELBO+KL(q||p)ELBO+KL(qp)是一个定值,而其中KL(q∣∣p)KL(q||p)KL(qp)是大于等于0的,且KL(q∣∣p)KL(q||p)KL(qp)越小代表概率分布pppqqq就越接近,也就是我们要优化的目标,但KL(q∣∣p)KL(q||p)KL(qp)中包含后验概率不好直接优化最小,但因为ELBO+KL(q∣∣p)ELBO+KL(q||p)ELBO+KL(qp)是定值,所以我们可以优化让ELBOELBOELBO部分最大,KL(q∣∣p)KL(q||p)KL(qp)相对就越小,这样便可以用概率分布qqq来代替ppp了。


三、公式推导

通过上一小节的描述已经明确了变分推断需要优化的目标,总结为如下公式:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SJmEbkXc-1617958057095)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image055.png)]

下面通过公式推导求解是的ELBOELBOELBO最大的后验概率q(z)q(z)q(z)的值,使用基于Mean Field的变分推断的解法求解后验概率分布p(z∣x)p(z|x)p(zx)

先假设z=(z1,z2,...,zM)z=(z_1,z_2,...,z_M)z=(z1,z2,...,zM),并且这M份之间是相互独立的,则有:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iPQcUKRe-1617958057099)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image061.png)]

接下来对ELBOELBOELBO项进行展开,并将q(z)q(z)q(z)的值代入:

在这里插入图片描述

下面为了简便,先做一下变量假设:

在这里插入图片描述

在推导AAABBB前,先固定z=(z1,...,zj−1,zj+1...,zM)z=(z_1,...,z_{j-1}, z_{j+1}...,z_M)z=(z1,...,zj1,zj+1...,zM),先zjz_jzj,接下来先推导AAA

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4kDY5r1z-1617958057135)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image081.png)]

其中有:

在这里插入图片描述

因此可以得出AAA的值如下:

在这里插入图片描述

接下来推导BBB

在这里插入图片描述

其中有:

在这里插入图片描述

因此得出了BBB的值:

在这里插入图片描述

因为固定了z=(z1,...,zj−1,zj+1...,zM)z=(z_1,...,z_{j-1}, z_{j+1}...,z_M)z=(z1,...,zj1,zj+1...,zM),只求未知量zjz_jzj,所以:

在这里插入图片描述

其中CCC是常量,至此有:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tpDxZftt-1617958057172)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image099.png)]

因此当KL取0时,ELBOELBOELBO能达到最大值,所以这里求出qj(zj)q_j(z_j)qj(zj)

在这里插入图片描述

其他的q1(z1),q2(z2),,...,qM(zM)q_1(z_1),q_2(z_2),,...,q_M(z_M)q1(z1),q2(z2),,...,qM(zM)求解方法相同。这样求出了q∗(z)q^{*}(z)q(z)求等价于求出了后验概率p(z∣x)p(z|x)p(zx)


正如文章开头结论所说,基于Mean Field的变分推断方法存在的两个问题,下一节变分推断将介绍另一种解法:基于随机梯度上升SGD的变分推断推导方案:

1、假设将z=(z1,z2,...,zM)z=(z_1,z_2,...,z_M)z=(z1,z2,...,zM) 分成M个相互独立的部分,然后固定其他依次求得qj(zj)q_j(z_j)qj(zj)。这个假设太强烈,在一些问题是无法分成相互独立的各个部分;

2、最后求出来的qj(zj)q_j(z_j)qj(zj)仍然是求积分,在一些问题中,仍然可能是Intractable,无法求解的。

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐