动机
- 在无监督情况下如何学习有用的表征很关键,希望实现一个模型在其隐空间保存数据重要特征的同时优化最大似然;
- 而采用离散编码更容易对先验建模,且连续的表征通常会被网络内在地离散化。
贡献
- 引入矢量量化的思想,将
VAE
与离散隐空间相结合,提出VQVAE
,它避免了posterior collapse
和variance issues
,且与连续编码模型表现相当; - 当离散隐编码与自回归
prior
配对使用,模型可在speech and video generation
等应用生成高质量的一致性样本,以及无监督学习。
本文的方法
VAE
通由以下部分构成:encoder
建模给定输入的离散隐编码的posterior distribution
,prior distribution
和decoder
建模likehliood
分布。通常假设prior distribution
为标准正态分布、posterior distribution
为对角方差的多元正态分布,隐编码为连续的随机变量。
VQVAE
跟VAE
主要的不同在于:1)encoder
的输出是离散的而不是连续的;2)prior
是可学习的而不是静态的。如下图所示,对于输入,encoder
输出,通过vector quantisation
使用离散隐编码,即通过最近邻从隐空间中找出离散隐变量,则先验分布和后验分布是可分类的,分布中的分量为映射embedding table
的index
,并将其作为decoder
的输入。
Discrete Latent variables
定义隐空间,其中为离散隐空间的大小,为每个隐编码的维度,即。如下图所示,对于输入,encoder
输出,通过最近邻从隐空间中找出离散隐变量,VAE
的后验分布为多元高斯分布,而VQVAE
的posterior categorical distribution
为one-hot
形式:
其中,为encoder
的输出,再通过最近邻从隐空间映射量化得到decoder
的输入,即:
Learning
类似VAE
,使用ELBO
约束,通过定义先验为均匀分布,则proposal distribution
为确定的,且KL divergence
为常数,即:
故相比VAE
,VQVAE
训练过程中未使用先验分布,仅使用重建损失,则第二阶段需要训练先验模型用以生成图像。由于argmin
导致不可导,则类似staright-through estimator
近似估计梯度,即直接将decoder
的输入的梯度复制给encoder
的输出。在前向计算时,最近邻的embedding
被传递给decoder
,在反向传播时,梯度被直接传递给encoder
,由于和共享空间,所以梯度包含encoder
如何改变输出去降低重构损失的信息。但embedding
没有从重构损失中接收到梯度,为了让embedding
参与训练,使用Vector Quantisation
,即计算encoder
的输出和向量量化后得到的embedding
之间的误差(也可以通过EMA
对embedding
进行更新)。同时,为了限制embedding
空间和encoder
输出的一致性,避免encoder
的输出变动过大,额外增加了一项commitment loss
,则总的损失函数如下:
其中,表示stop gradient
,decoder
只优化第一项重建损失,encoder
优化损失的第一项最后一项,embedding
优化中间的commitment loss
,文章所有实验均设置。对于一张图像,在计算第二项和第三项损失时将计算个离散隐编码的损失的平均值。
Prior
训练VQVAE
时,先验分布恒定为均匀分布,训练完成后需要在上拟合一个自回归分布,以通过抽样生成离散编码。即使用VQVAE
对训练图像进行推理得到对应的离散编码,再采用PixelCNN
对离散编码建模,基于softmax
预测类别。生成图像时则通过训练好的PixelCNN
采样离散编码,再送入VQVAE
的decoder
生成图像。
部分实验结果
在ImageNet
上重建的结果
在ImageNet
训练的VQVAE
,之后从PixelCNN
采样先验后生成的结果