论文信息标题Auto-Encoding Variational Bayes会议ICLR 2014单位阿姆斯特丹大学代码https://github.com/dpkingma/vae论文https://arxiv.org/pdf/1312.6114.pdf一、前言生成模型的“不可能三角”在VAE出现之前深度生成模型一直被三个难题卡住后验概率不可算p ( z ∣ x ) p(z|x)p(z∣x)无法直接求解大规模数据训不动传统变分推断不支持小批量SGD采样与推断割裂生成和编码不能一套模型搞定这篇论文直接用变分推断重参数化一把梭哈从此VAE成为生成模型三大支柱之一。二、核心思想一句话讲透编码器Encoder输入图片x xx输出隐变量z zz的分布q ϕ ( z ∣ x ) q_\phi(z|x)qϕ​(z∣x)解码器Decoder输入隐变量z zz输出重建图片p θ ( x ∣ z ) p_\theta(x|z)pθ​(x∣z)训练目标让边缘似然下界最大既保证重建准又保证生成真实通俗解释不是普通自编码器只学“编码→解码”而是学概率分布能从噪声随机采样生成全新图片。三、整体架构图1 VAE概率图模型实线生成模型p θ ( z ) p θ ( x ∣ z ) p_\theta(z)p_\theta(x|z)pθ​(z)pθ​(x∣z)虚线近似后验q ϕ ( z ∣ x ) q_\phi(z|x)qϕ​(z∣x)θ \thetaθ解码器参数ϕ \phiϕ编码器参数四、核心公式全解析4.1 对数似然下界ELBOlog ⁡ p θ ( x ( i ) ) ≥ L ( θ , ϕ ; x ( i ) ) \log p_\theta(x^{(i)}) \ge \mathcal{L}(\theta,\phi;x^{(i)})logpθ​(x(i))≥L(θ,ϕ;x(i))L − D K L ( q ϕ ( z ∣ x ) ∥ p θ ( z ) ) E q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] \mathcal{L} -D_{KL}(q_\phi(z|x) \parallel p_\theta(z)) \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]L−DKL​(qϕ​(z∣x)∥pθ​(z))Eqϕ​(z∣x)​[logpθ​(x∣z)]L \mathcal{L}L证据下界越大越好D K L D_{KL}DKL​KL散度衡量分布差异q ϕ ( z ∣ x ) q_\phi(z|x)qϕ​(z∣x)编码分布近似后验p θ ( z ) p_\theta(z)pθ​(z)先验分布标准高斯p θ ( x ∣ z ) p_\theta(x|z)pθ​(x∣z)解码分布生成图像E \mathbb{E}E期望通俗解释左边让编码靠近先验规范分布右边让重建尽可能准。4.2 重参数化技巧VAE能训的关键z μ σ ⊙ ϵ , ϵ ∼ N ( 0 , I ) z \mu \sigma \odot \epsilon,\quad \epsilon \sim \mathcal{N}(0,I)zμσ⊙ϵ,ϵ∼N(0,I)z zz隐变量采样μ \muμ编码器输出均值σ \sigmaσ编码器输出标准差ϵ \epsilonϵ标准高斯噪声⊙ \odot⊙按元素相乘通俗解释把随机性甩给固定噪声ϵ \epsilonϵ让z zz可导才能用反向传播训练。4.3 高斯先验下的KL闭式解− D K L 1 2 ∑ j 1 J ( 1 log ⁡ σ j 2 − μ j 2 − σ j 2 ) -D_{KL} \frac{1}{2}\sum_{j1}^J \left(1\log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)−DKL​21​j1∑J​(1logσj2​−μj2​−σj2​)J JJ隐变量维度μ j , σ j \mu_j,\sigma_jμj​,σj​第j jj维的均值、方差五、核心PyTorch代码5.1 VAE Encoder输出μ, logvarimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFclassEncoder(nn.Module):def__init__(self,in_dim784,hidden_dim400,latent_dim20):super().__init__()self.fc1nn.Linear(in_dim,hidden_dim)self.fc_munn.Linear(hidden_dim,latent_dim)self.fc_logvarnn.Linear(hidden_dim,latent_dim)defforward(self,x):hF.relu(self.fc1(x))muself.fc_mu(h)logvarself.fc_logvar(h)returnmu,logvar5.2 VAE DecoderclassDecoder(nn.Module):def__init__(self,latent_dim20,hidden_dim400,out_dim784):super().__init__()self.fc2nn.Linear(latent_dim,hidden_dim)self.fc3nn.Linear(hidden_dim,out_dim)defforward(self,z):hF.relu(self.fc2(z))x_recontorch.sigmoid(self.fc3(h))returnx_recon5.3 重参数化 损失函数classVAE(nn.Module):def__init__(self):super().__init__()self.encoderEncoder()self.decoderDecoder()defreparameterize(self,mu,logvar):stdtorch.exp(0.5*logvar)epstorch.randn_like(std)returnmueps*stddefforward(self,x):mu,logvarself.encoder(x)zself.reparameterize(mu,logvar)x_reconself.decoder(z)# 损失重构损失 KL散度recon_lossF.binary_cross_entropy(x_recon,x,reductionsum)kl_loss-0.5*torch.sum(1logvar-mu.pow(2)-logvar.exp())returnrecon_losskl_loss六、实验结果与对比6.1 对数似然下界对比表格1 出处原论文Figure 2模型MNIST测试集下界Wake-Sleep约105VAE(AEVB)约140表格1 训练收敛速度对比分析VAE收敛更快、更高、更稳完爆传统Wake-Sleep。6.2 隐空间可视化图2 2维隐空间分布分析VAE学到光滑连续的流形数字之间平滑过渡可插值生成。6.3 不同隐维度采样效果图3 不同维度隐变量生成的MNIST分析隐维度≥10即可生成清晰数字维度越高细节越丰富。七、关键创新点SGVB估计器变分下界可微、可小批量训练重参数化技巧解决采样不可导问题AEVB算法编码解码联合训练一套框架搞定生成与推断理论优美为后续CV、NLP生成模型奠定基础八、总结VAE是深度生成模型的里程碑第一次把变分推断和深度网络完美结合用重参数化解决采样不可导的世纪难题支持大规模数据、端到端训练、随机采样生成今天几乎所有可控生成、隐空间分析、概率建模都能看到VAE的影子。