cff064de34ba73a61fe5cebd16a39f38.png

本文作者:Light Sea@知乎。未经作者允许,本文禁止转载,谢谢合作。

原论文题目《Deep unsupervised clustering with gaussian mixture variational autoencoders》。

本文我将介绍VAE针对无监督聚类的一个扩展:GMVAE,即基于高斯混合模型的VAE。我们在之前的文章中已经介绍了VAE,它是一个无监督的生成模型,其良好的性能和end-to-end的性质让它在深度学习时代被广泛应用。而GMVAE则将VAE的相关技术应用到无监督聚类问题上,其思想在于通过扩展latent variable structure提升VAE的聚类性能。

1. Introduction

无监督聚类一直是机器学习中比较重要的题目,传统方法例如我们熟知的k-means或者高斯混合模型(GMM)到现在仍然被广泛使用。但在深度学习快速发展的今天,作者希望可以借助深度学习的力量来提升这个问题的表现。

传统的deep generative model,例如VAE,并不是专门为了聚类而设计的。但是像VAE这样的深度生成模型中包含了一个重要的东西:latent variable。实际上潜变量建模了一些抽象的概念,潜变量的结构在某些时候就包含了类别信息。举个例子,如果我们使用MNIST数据来训练VAE,我们就能观察到不同种类数字的潜变量实际上在潜变量空间中形成了不同的cluster。

然而这样虽然有可能完成聚类,但却没有任何显式的能完成聚类的保证。另外,潜变量的先验分布很简单,因此不能建模复杂数据。因此在这篇文章中作者基于VAE设计了一个专门用于聚类的算法。算法假设数据生成于一个多模态(multimodal)的先验分布。作者构建了一个end-to-end的模型,从而可以用BP轻易更新整个网络。

2. Method

2.1 Generative and recognition models

作者选择混合高斯作为先验分布。我们考虑一个生成模型:

84d4b475bf292d28c8ea0763aba89462.png

这里

都是潜变量,
是要生成的模型,
是模型参数。

对这几个潜变量有下面的假设:

cabe549f72d75c3d2a5fd53b04094371.png

什么意思呢?模型首先从标准高斯分布中采样一个

,然后再用这个作为输入从以
为参数的神经网络中生成均值向量
和标准差向量
。这实际上是混合高斯中每个类别的均值和标准差,但是数据只可能属于一类,因此引入了多项分布随机变量
,假设
遵从离散的
值均匀分布,通过采样就可以得到
。这样我们就得到了
的分布,即混合高斯中第
个成分的分布。

假设随见变量

遵从高斯分布,我们用
作为输入到从以
作为参数的神经网络中生成均值
和标准差
就得到了
的分布。

下图展示了这个模型的graphical model:

4d72b856826f34ed8cfe6687d468ac89.png

2.2 Inference with the recognition model

模型通过优化ELBO来训练:

384b1527abb208bfc8659dddd4b33ff2.png

应用一些独立性假设,我们得到下面的分解:

30283872c8e40c6d54498d34c088e7d5.png

这里

代表数据的index,在之后的公式中我们省略它。

的后验分布
可以写成:

1da72a856524e4055d54a5867f3adfd0.png

经过推导可以得到ELBO的简化表达式:

86a727d04ffa9dc3b5eaeb4ea60d6e53.png

我们分别称上式右边的四个项为reconstruction term, conditional prior term,

-prior term和
-prior term。

我们分别讨论这几项。

这里Reconstruction term其实就是最常见的重构建损失,用相应的损失函数比如MSE就可以;而

-prior就和普通的VAE中的KL项类似,因此也可以轻易计算。

我们重点看后面的几项,首先conditional prior term可以用下面的式子来估计:

0323c18fdbb602c4404c2fc18ca882e3.png

这里由于对于所有的

,分布
都能通过网络一次计算得到,因此这个整个模型的参数更新只需要一次前向传播。

2.3 The over-regularization problem

所谓over-regularization problem是指ELBO中KL项被过度优化从而导致KL项迅速变为0的问题。这种情况下,latent variable会变得非常简单且不能很好的表示数据的结构。一般来说有两种办法解决这个问题。一是给KL项添加权重,边训练边增加权重;二是为KL项设置一个阈值,一旦它小于这个值就不再优化KL项。这里作者使用第二种方法来调整

-prior term,相关公式为:

ca3871756e052e4fbd3fe45c2d650757.png

3. Experiment

作者一共做了三个部分的实验,一是在synthetic data上做的用来验证和解决over-regularization问题的实验;二是在MNIST上训练无监督聚类的实验;最后是用随机噪声生成图片的实验。

我们重点看后面两个实验。

3.1 Unsupervised image clustering

作者使用MNIST数据集进行无监督聚类,结果如下表所示:

eac369609cc85427dc2f86a600d8cfff.png

可以看到GMVAE的表现比之前的绝大多数方法要好,但要劣于adversarial auto-encoder(AAE)。

另外作者还观察到大的K和M能增加模型表现和稳定性,如下图所示:

e08f6964bebb780cfc104027ff253817.png

3.2 Image generation

作者设置噪声w为0,然后在混合高斯的不同成分中采样来生成手写数字,结果发现不同成分对应了不同的数字,这就证明GMVAE成功地学习到了不同的类别。

作者还固定混合高斯的成分,然后改变噪声w,发现噪声w实际上控制了数字的风格。

这两个实验的结果如下图所示:

df3d0f7cc3fdafc1e48203474a968879.png

4. Conclusion

GMVAE改进了传统的VAE,使得latent variable structure更加复杂,从而能够更好的完成无监督聚类任务。作者提出的latent variable具有明显的层次结构,这就使得latent variable更易被理解,因此有更好的性能也就不足为奇了。

创作不易,如果各位感觉到有收获请点赞收藏支持一下,你的支持就是我创作的最大动力。

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐