这篇文章是 数学建模 课程的非标报告,由生成式模型协助完成。实际写于 2025/6/6。

VAE的发展历程

1980s:自编码器的诞生与早期探索

1987年,Yann LeCun首次提出自编码器(Autoencoder, AE),其核心思想是通过编码器-解码器结构压缩数据到低维空间,再重构原始输入。早期的AE主要用于数据降维和特征提取。

典型的AE结构如下:

flowchart LR
A[原始图像] -->B{Encoder}
B --> C[潜空间中的编码向量]
C --> D{Decoder}
D --> E[重构图像]

其编码器、解码器均是一个在不同空间中映射的函数,比编码器将原始数据直接映射到潜空间,而解码器直接将将潜空间映射到原始空间。

但是,这个模型的缺陷也是显而易见的,其存在生成能力不足、潜在空间不连续等问题。例如,传统AE的的潜空间是很割裂的,无法进行插值,也无法进行生成。

1990s-2000s:生成模型的初步尝试

1994年,Hinton等人提出基于AE的生成模型框架,尝试用最小描述长度原理(MDL)建模数据分布。也就是将传统数据压缩算法引入神经网络。

然而,传统生成模型(如高斯混合模型)依赖复杂的推断方法(如MCMC),难以处理高维数据。这一阶段的研究为概率化潜在空间奠定了基础,但计算效率仍是瓶颈。

2013年:VAE的突破性提出

Diederik P. Kingma和Max Welling发表论文《Auto-Encoding Variational Bayes》,首次将变分推断(Variational Inference)引入自编码器,解决了传统AE的生成缺陷。VAE的核心创新包括:

  1. 概率化潜在空间:假设隐变量z服从标准正态分布p(z)=N(0,I)p(z)=\mathcal{N}(0,I),使生成过程可控。
  2. 变分下界(ELBO)优化:联合优化重构误差和KL散度,平衡生成质量与分布匹配。

模型的构成

符号说明

符号 解释
Encoder 编码器,是一种函数,将图像输入映射为压缩表示:RH×W×CRLR^{H\times W\times C}\rightarrow R^L
Decoder 解码器,是一种函数,将压缩表示重建为原始图像:RLRH×W×CR^L \rightarrow R^{H\times W\times C}
潜空间 压缩表示所在的空间:RLR^L
z 图像的压缩表示:zRLz\in R^L
x 原始图像:xRH×W×Cx\in R^{H\times W\times C}

模型解释

VAE是AE(Autoencoders,自编码器)的一种,同样也包含了Encoder和Decoder,Encoder负责将数据降维,Decoder负责将数据升维复原。
但VAE巧妙地将概率统计融入了深度神经网络。将Encoder的输出视为一个概率分布,然后从这个分布中采样得到一个新的向量,这个新的向量就可以作为Decoder的输入。将一个概率论问题转化为一个优化问题,并通过深度神经网络强大的拟合能力去优化参数。从而提升了模型的泛化能力与鲁棒性,并增强了模型的可解释性。

对于AE而言,其潜空间缺乏连续性,且容易过拟合,从而会导致其Encoder、Decoder的泛化能力较差。

而VAE引入了统计学的观点,Decoder作为一个生成模型,那是不是意味着Decoder能给出一个概率分布:

xP(XZ=z)\begin{gather*} x\sim P(X|Z=z) \end{gather*}

那么,我们只需要取一个似然比较大的xx,就可以得到一个重构图像。

在这个思想下,对于特定图像,其在潜空间的表示xx同样是一个概率分布xP(ZX=x)x\sim P(Z|X=x)

当然,VAE为了能进行推导建模,也将实际情况做了一些简化。VAE将各个空间中的向量元素都看作独立正态分布

flowchart LR
A[原始图像] -->B{Encoder}
B --> C[均值μᵢ]
B --> D[方差σᵢ]
C --> E["对分布N(μᵢ, σᵢ²)采样得Zᵢ"]
D --> E
E --> F{Decoder}
F --> G[均值μᵢ]
F --> H[方差σᵢ]
G --> I["对分布N(μᵢ, σᵢ²)采样得Xᵢ"]
H --> I

但在实际应用中,我们会直接将Decoder输出的μ\mu作为结果,而不进行采样操作。因为对于正态分布而言,当x=μx=\mu时的似然一定是最大的,因此我们可以直接将μ\mu作为结果。

flowchart LR
A[原始图像] -->B{Encoder}
B --> C[均值μᵢ]
B --> D[方差σᵢ]
C --> E["对分布N(μᵢ, σᵢ²)采样得Zᵢ"]
D --> E
E --> F{Decoder}
F --> G[重构图像]

Note:其中菱形表示函数,方框表示值。

而在隐空间中,其中对于分布的采样操作,会强迫网络学到连续性的潜空间(因为同一个N(μ,σ2)N(\mu,\sigma^2)的采样结果是随机的),从而提升模型的泛化能力与鲁棒性。

模型的求解

符号定义

符号 解释
ϕ\phi Encoder的参数
θ\theta Decoder的参数
pθ(X)p_{\theta}(X) 在参数θ\theta下图像的概率分布
pθ(XZ)p_{\theta}(X|Z) X的后验分布(Decoder)
pθ(Z)p_{\theta}(Z) ZZ的先验分布
pθ(ZX)p_{\theta}(Z|X) ZZ的真实后验分布
qϕ(ZX)q_{\phi}(Z|X) ZZ的后验分布(Encoder拟合出来的)

变分推断

VAE的核心在于V,Variational Inference,即变分推断,通过近似后验分布来进行建模。也就是利用拟合的后验分布qϕ(ZX)q_{\phi}(Z\mid X)对隐变量进行建模。

在这个视角下,Encoder其实就是p(ZX)p(Z|X),而Decoder其实就是p(XZ)p(X|Z)

首先,我们的目标是最大化Decoder产生真实图像的似然,即最大化p(X=x)p(X=x)

logpθ(X=x)=[zqϕ(Z=zX=x)dz]logpθ(X=x)=zqϕ(Z=zX=x)logpθ(X=x)dz=zqϕ(Z=zX=x)logpθ(X=x,Z=z)pθ(Z=zX=x)dz贝叶斯定理=zqϕ(Z=zX=x)log(pθ(X=x,Z=z)qϕ(Z=zX=x)qϕ(Z=zX=x)pθ(Z=zX=x))dz=zqϕ(Z=zX=x)logpθ(X=x,Z=z)qϕ(Z=zX=x)dz+zqϕ(Z=zX=x)logqϕ(Z=zX=x)pθ(Z=zX=x)dz=(pθ,qϕ)+DKL(qϕ,pθ)Let (pθ,qϕ)=logpθ(X=x)DKL(pθ,qϕ)(pθ,qϕ)\begin{aligned} \log p_{\theta}(X=x) &= [\int_z q_{\phi}(Z=z \mid X=x) dz]\log p_{\theta}(X=x)\\ &=\int_{z} q_{\phi}(Z=z \mid X=x) \log p_{\theta}(X=x) dz\\ &= \int_{z} q_{\phi}(Z=z \mid X=x) \log \frac{p_{\theta}(X=x,Z=z)}{p_{\theta}(Z=z \mid X=x)} dz \quad \text{贝叶斯定理}\\ &= \int_{z} q_{\phi}(Z=z \mid X=x) \log \left( \frac{p_{\theta}(X=x,Z=z)}{q_{\phi}(Z=z \mid X=x)} \cdot \frac{q_{\phi}(Z=z \mid X=x)}{p_{\theta}(Z=z \mid X=x)} \right) dz \\ &= \int_{z} q_{\phi}(Z=z \mid X=x) \log \frac{p_{\theta}(X=x,Z=z)}{q_{\phi}(Z=z \mid X=x)} dz + \\&\int_{z} q_{\phi}(Z=z \mid X=x) \log \frac{q_{\phi}(Z=z \mid X=x)}{p_{\theta}(Z=z \mid X=x)} dz \\ &= \ell(p_{\theta}, q_{\phi}) + D_{KL}(q_{\phi}, p_{\theta}) \quad \text{Let } \ell(p_{\theta}, q_{\phi})=\log p_\theta(X=x)-D_\text{KL}(p_{\theta}, q_{\phi}) \\ &\geq \ell(p_{\theta}, q_{\phi}) \end{aligned}

也就是说,我们需要最大化(pθ,qϕ)\ell(p_{\theta}, q_{\phi}),这样就能让Decoder产生真实图像的似然最大化。

(pθ,qϕ)=logpθ(X=x)DKL(pθ,qϕ)\ell(p_{\theta}, q_{\phi})=\log p_\theta(X=x)-D_\text{KL}(p_{\theta}, q_{\phi})

在最大化\ell的时,我们同时增大了pθ(X=x)p_{\theta}(X=x),并减小了DKL(pθ,qϕ)D_\text{KL}(p_{\theta}, q_{\phi})。意味着Decoder产生真实图像的似然增大,而Encoder拟合的后验分布qϕ(ZX)q_{\phi}(Z\mid X)与真实后验分布pθ(ZX)p_{\theta}(Z\mid X)的距离减小。

(pθ,qϕ)=zqϕ(Z=zX)logpθ(X,Z=z)qϕ(Z=zX)dz=zqϕ(Z=zX)logpθ(XZ=z)p(Z=z)qϕ(Z=zX)dz贝叶斯定理=zqϕ(Z=zX)logp(Z=z)qϕ(Z=zX)dz+zqϕ(Z=zX)logpθ(XZ=z)dz=DKL(qϕ,p)+Eqϕ[logpθ(XZ=z)]\begin{aligned} \ell(p_{\theta}, q_{\phi}) &= \int_{z} q_{\phi}(Z=z \mid X) \log \frac{p_{\theta}(X,Z=z)}{q_{\phi}(Z=z \mid X)} dz \\ &= \int_{z} q_{\phi}(Z=z \mid X) \log \frac{p_{\theta}(X \mid Z=z)p(Z=z)}{q_{\phi}(Z=z \mid X)} dz \quad \text{贝叶斯定理}\\ &= \int_{z} q_{\phi}(Z=z \mid X) \log \frac{p(Z=z)}{q_{\phi}(Z=z \mid X)} dz + \int_{z} q_{\phi}(Z=z \mid X) \log p_{\theta}(X \mid Z=z) dz \\ &= -D_{KL}(q_{\phi}, p) + \mathbb{E}_{q_{\phi}}[\log p_{\theta}(X \mid Z=z)] \end{aligned}

因为zN(0,1)=p(z)z\sim N(0,1)=p(z)

因此,我们希望能最小化正则化项DKL(qϕ,N(0,1))D_{KL}(q_{\phi}, N(0,1)),并最大化Eqϕ[logpθ(Xz)]\mathbb{E}_{q_{\phi}}[\log p_{\theta}(X \mid z)]

减小正则化项

正则项即是DKL(qϕ,N(0,1))D_{KL}(q_{\phi}, N(0,1))

正则项的直观解释:

  1. 若Encoder学会了投机取巧,将σ\sigma始终输出为0,那么VAE就退化为了普通的AE。
  2. VAE同时也是生成模型,也就是说,我们希望潜空间上的任何一点都包含有效信息,即我们将噪声(一般由N(0,1)N(0,1)生成)作为输入数据,Decoder同样能够解码出有效的图像。

显然,正则化项(与正态分布的KL散度)能很好地解决这个问题,在模型的训练过程中,惩罚模型偏离标准正态分布的行为。

下图显示了一个比较极端的例子(潜空间上一维的特征):

一个比较极端的例子

如何不使用正则化项,那么对于x=0x=0的点,Decoder无法理解,因为潜空间实际上是很大的,极有可能某些子空间没有被使用。而当有正则化项时,Decoder会学会该特征是一种介于特征1与特征2之间的特征。

数学推导

P=N(μ,σ2);Q=N(0,1)P=N(\mu,\sigma^2);Q=N(0,1)

DKL(PQ)=p(x)(lnp(x)lnq(x))dxD_{\text{KL}}(P | Q) = \int p(x) (\ln p(x) - \ln q(x) ) dx

代入 PDF:

lnp(x)=ln(σ2π)(xμ)22σ2,lnq(x)=ln(2π)x22\ln p(x) = -\ln(\sigma \sqrt{2\pi}) - \frac{(x - \mu)^2}{2\sigma^2}, \quad \ln q(x) = -\ln(\sqrt{2\pi}) - \frac{x^2}{2}

lnp(x)lnq(x)=lnσ(xμ)22σ2+x22\ln p(x) - \ln q(x) = -\ln \sigma - \frac{(x - \mu)^2}{2\sigma^2} + \frac{x^2}{2}

积分计算:

  1. p(x)(lnσ)dx=lnσ\int p(x) (-\ln \sigma) dx = -\ln \sigma

  2. p(x)((xμ)22σ2)dx=12\int p(x) \left(-\frac{(x - \mu)^2}{2\sigma^2}\right) dx = -\frac{1}{2}(因为 EP[(xμ)2]=σ2\mathbb{E}_P[(x - \mu)^2] = \sigma^2

  3. p(x)x22dx=12(σ2+μ2)\int p(x) \frac{x^2}{2} dx = \frac{1}{2} (\sigma^2 + \mu^2)(因为 EP[x2]=σ2+μ2\mathbb{E}_P[x^2] = \sigma^2 + \mu^2

求和得:

DKL=lnσ12+12(σ2+μ2)=12μ2+12σ2lnσ12D_{\text{KL}} = -\ln \sigma - \frac{1}{2} + \frac{1}{2} (\sigma^2 + \mu^2) = \frac{1}{2} \mu^2 + \frac{1}{2} \sigma^2 - \ln \sigma - \frac{1}{2}

因此:

DKL=12(μ2+σ2ln(σ2)1)D_{KL}=\frac{1}{2} \left( \mu^2 + \sigma^2 - \ln(\sigma^2) - 1 \right)

由于已经假设了zz中各分量都是独立的,因此整个潜空间的正则化损失为:

lossKL=12i=1L(μi2+σi2ln(σi2)1)\text{loss}_{KL}=\frac{1}{2} \sum_{i=1}^{L} \left( \mu_i^2 + \sigma_i^2 - \ln(\sigma_i^2) - 1 \right)

最大化真实图像的似然

我们的目标是最大化期望Eqϕ[logpθ(Xz)]\mathbb{E}_{q \phi}[\log p_{\theta}(X \mid z)]

我们再回顾一下我们的前提

  1. μi\mu_i' 是解码器 pθ(xz)p_{\theta}(x \mid z) 输出的均值(即重构后的数据点均值),对应于通过编码器 qϕ(zxi)q_{\phi}(z \mid x_i) 采样得到的潜在变量 ziz_i
  2. nn 是数据集大小。
  3. 数据集 X={x1,x2,,xn}X = \{x_1, x_2, \dots, x_n\} 是独立同分布(i.i.d.)的。

解码器 pθ(xz)p_{\theta}(x \mid z) 是各向同性的高斯分布:pθ(xz)=N(x;μθ(z),σ2I)p_{\theta}(x \mid z) = \mathcal{N}(x; \mu_{\theta}(z), \sigma^2 I),其中 μθ(z)\mu_{\theta}(z) 是解码器输出的均值向量,σ2\sigma^2 是方差,I 是单位矩阵。

编码器 qϕ(zxi)q_{\phi}(z \mid x_i) 定义后验分布,采样 zqϕ(zxi)z \sim q_{\phi}(z \mid x_i) 用于估计期望。

重构项是对每个数据点独立处理的,因此 Eqϕ[logpθ(Xz)]\mathbb{E}_{q_{\phi}}[\log p_{\theta}(X \mid z)] 实际上应理解为对数据点求和的形式 i=1nEqϕ(zxi)[logpθ(xiz)]\sum_{i=1}^n \mathbb{E}_{q_{\phi}(z \mid x_i)}[\log p_{\theta}(x_i \mid z)]

其中,Eqϕ(zxi)[logpθ(xiz)]\mathbb{E}_{q_{\phi}(z \mid x_i)}[\log p_{\theta}(x_i \mid z)] 是重构项,我们需最大化此项。

给定解码器为高斯分布:

pθ(xiz)=N(xi;μθ(z),σ2I)p_{\theta}(x_i \mid z) = \mathcal{N}(x_i; \mu_{\theta}(z), \sigma^2 I)

其对数似然为:

logpθ(xiz)=12σ2 xiμθ(z) 2d2log(2π)dlogσ\log p_{\theta}(x_i \mid z) = -\frac{1}{2\sigma^2} \ x_i - \mu_{\theta}(z)\ ^2 - \frac{d}{2} \log(2\pi) - d \log \sigma

dd 是数据点 xix_i 的维度。

第二项和第三项是常数(不依赖于 z 或数据)。因此,简化为:

logpθ(xiz)=12σ2 xiμθ(z) 2+C\log p_{\theta}(x_i \mid z) = -\frac{1}{2\sigma^2} \ x_i - \mu_{\theta}(z)\ ^2 + C

其中 C=d2log(2π)dlogσC = -\frac{d}{2} \log(2\pi) - d \log \sigma 是常数。

重构项是期望:

Eqϕ(zxi)[logpθ(xiz)]=Eqϕ(zxi)[12σ2 xiμθ(z) 2+C]\mathbb{E}_{q_{\phi}(z \mid x_i)}[\log p_{\theta}(x_i \mid z)] = \mathbb{E}_{q_{\phi}(z \mid x_i)} \left[ -\frac{1}{2\sigma^2} \ x_i - \mu_{\theta}(z)\ ^2 + C \right]

由于期望是线性的,且 C 是常数:

Eqϕ(zxi)[logpθ(xiz)]=12σ2Eqϕ(zxi)[ xiμθ(z) 2]+C\mathbb{E}_{q_{\phi}(z \mid x_i)}[\log p_{\theta}(x_i \mid z)] = -\frac{1}{2\sigma^2} \mathbb{E}_{q_{\phi}(z \mid x_i)} \left[ \ x_i - \mu_{\theta}(z)\ ^2 \right] + C

这里,Eqϕ(zxi)[( xiμθ(z))2]\mathbb{E}_{q_{\phi}(z \mid x_i)}[( \ x_i - \mu_{\theta}(z))^2 ]是期望重构误差(条件 MSE)。

利用蒙特卡罗估计:期望 Eqϕ(zxi)[]\mathbb{E}_{q_{\phi}(z \mid x_i)}[\cdot] 通常通过采样估计(常用单样本蒙特卡罗估计)。对于每个数据点 xix_i
从编码器采样:ziqϕ(zxi)z_i \sim q_{\phi}(z \mid x_i)

解码器输出均值:μi=μθ(zi)\mu_i' = \mu_{\theta}(z_i)

因此,期望被近似为:

Eqϕ(zxi)[ xiμθ(z) 2] xiμθ(zi) 2= xiμi 2\mathbb{E}_{q_{\phi}(z \mid x_i)} \left[ \ x_i - \mu_{\theta}(z)\ ^2 \right] \approx \ x_i - \mu_{\theta}(z_i)\ ^2 = \ x_i - \mu_i'\ ^2

代入重构项:

Eqϕ(zxi)[logpθ(xiz)]12σ2 xiμi 2+C\mathbb{E}_{q_{\phi}(z \mid x_i)}[\log p_{\theta}(x_i \mid z)] \approx -\frac{1}{2\sigma^2} \ x_i - \mu_i'\ ^2 + C

最大化重构项等价于最大化近似值:

max(12σ2 xiμi 2+C)\max \left( -\frac{1}{2\sigma^2} \ x_i - \mu_i'\ ^2 + C \right)

由于 12σ2<0-\frac{1}{2\sigma^2} < 0CC 是常数,最大化这一项等价于最小化  (xiμi)2\ (x_i - \mu_i')^2(因为负的二次项):

max(12σ2 xiμi 2)    min (xiμi)2\max \left( -\frac{1}{2\sigma^2} \ x_i - \mu_i'\ ^2 \right) \iff \min \ (x_i - \mu_i')^2

扩展到整个数据集,整个数据集的重构项是每个数据点的和:

i=1nEqϕ(zxi)[logpθ(xiz)]i=1n(12σ2 xiμi 2+C)\sum_{i=1}^n \mathbb{E}_{q_{\phi}(z \mid x_i)}[\log p_{\theta}(x_i \mid z)] \approx \sum_{i=1}^n \left( -\frac{1}{2\sigma^2} \ x_i - \mu_i'\ ^2 + C \right)

最大化这个总和等价于最小化:

i=1n(xiμi)2\sum_{i=1}^n (x_i - \mu_i')^2

模型求解

因此,VAE的优化目标可以表示为:

mini=1n(xiμi)2+12i=1L(μi2+σi2ln(σi2)1)\min \sum_{i=1}^n ( x_i - \mu_i')^2 + \frac{1}{2} \sum_{i=1}^{L} \left( \mu_i^2 + \sigma_i^2 - \ln(\sigma_i^2) - 1 \right)

通常使用Adam优化器来最小化这个目标函数。

实例 展示

以Fashion-MNIST作为数据集,VAE的效果如下:

(代码参见 VAE示例的实现

*:由于Encoder、Decoder、超参数的选择不同,VAE的效果也会有所不同。

VAE的应用

乍一看,VAE不过就是压缩了一下数据,然后再解压出来,似乎就是一个压缩图片体积的东西。但实际上,VAE的应用远不止于此。

这个压缩其实是非常有作用的。我们不妨看看原始的图像空间是什么:其数据量大,信息密度低,而对于潜空间呢,其数据量小,信息密度高,并且还很具有连续性!

图像生成方面

既然VAE解决了潜空间不连续的问题,那么我们就可以利用VAE来进行图像生成。我们只需要在潜空间上随机采样,那么我们将其传入Decoder,就可以得到一个新的图像。

我们在之前又假设了潜空间中的数据符合独立标准正态分布,那么我们直接从N(0,I)N(0,I)中采样即可。

以下图像就是从N(0,I)N(0,I)中采样得到的图像(Fashion-MNIST作为数据集):

(代码参见 VAE示例的实现

CVAE:条件VAE[3]

不过,很容易发现上述VAE作为图像生成器存在一个问题,就是我们无法指定生成的类别。

CVAE的核心改进是:引入条件变量。

VAE的生成过程是无监督的——输入数据xx通过编码器映射到潜在空间zz,再从zz随机采样生成输出xx',因此用户无法控制生成内容(例如生成特定类别“裤子”)。
而CVAE(Conditional VAE)的关键改进是加入条件变量yy(如类别标签、文本描述等),使生成过程变为有监督:
训练阶段:编码器和解码器同时接收条件yy,学习p(zx,y)p(z|x,y)p(xz,y)p(x| z,y)

生成阶段:通过指定yy(如标签“裤子”),控制生成结果(裤子的图片)。

VAE的潜在空间:zN(0,I)z \sim \mathcal{N}(0,I),完全随机。

CVAE的潜在空间:zN(μy,σy)z \sim \mathcal{N}(\mu_y, \sigma_y),其中μy\mu_yσy\sigma_y由条件yy决定。例如,不同类别的数据对应不同的高斯分布。

损失函数:在VAE的ELBO(证据下界)基础上,增加条件约束:

LCVAE=Eq(zx,y)[logp(xz,y)]βKL(q(zx,y) p(zy)) \mathcal{L}_{\text{CVAE}} = \mathbb{E}_{q(z|x,y)}[\log p(x|z,y)] - \beta \cdot \text{KL}(q(z|x,y) \ p(z|y))

其中β\beta调节生成质量与条件匹配的平衡。

不难发现,CVAE更进一步拓展了应用场景

CVAE解决了VAE的生成不可控问题,适用于:

  1. 图像生成:指定类别生成图片、黑白图上色。
  2. 多模态能力:根据文本描述生成图像。

​​VAE-GAN

VAE-GAN是一种结合了​​变分自编码器(VAE)​​和​​生成对抗网络(GAN)​​的深度学习模型。其核心思想是​​用VAE的结构提供稳定的潜在空间,再用GAN的判别器提升生成质量​​,从而解决传统VAE生成模糊图像和GAN训练困难的问题

其就是在VAE的基础上,加入了GAN的判别器,用于提高生成质量。

总损失函数可表示为:

LVAE-GAN=LVAE+λLGAN\mathcal{L}_{\text{VAE-GAN}} = \mathcal{L}_{\text{VAE}} + \lambda \mathcal{L}_{\text{GAN}}

其中 λ\lambda 平衡两部分权重。

多任务学习

这基于这样一个事实,既然潜空间编码能够有效地表示原始图像,而且潜空间中的信息密度很高,那么是否意味着我们可以直接利用潜空间编码替代原始图像,从而进行下游任务。

半监督图像分类模型

得益于VAE是无监督模型,那么我们将其与有监督模型进行结合,就可以进行半监督学习。

flowchart LR
    %% 左侧输入数据
    A[(无标签数据)]:::unlabeled
    B[(有标签数据)]:::labeled

    %% 中间特征提取
    C[Encoder特征提取]:::encoder
    D{压缩表示}:::latent

    %% 右侧输出分支
    E[Decoder重建图像]:::decoder
    F[重建损失]:::loss
    G[Classifier分类器]:::classifier
    H[分类损失]:::loss

    %% 连接关系
    A --> C
    B --> C
    C --> D
    D --> E --> F
    D --> G --> H

    %% 样式定义
    classDef unlabeled fill:#9f9,stroke:#090
    classDef labeled fill:#ff9,stroke:#990
    classDef encoder fill:#f99,stroke:#900
    classDef latent fill:#99f,stroke:#009
    classDef decoder fill:#f99,stroke:#900
    classDef classifier fill:#f99,stroke:#900
    classDef loss fill:#ddd,stroke:#555

*:潜空间采样未画出

针对有标签数据,我们可以同时优化分类损失与重建损失。
而对于无标签数据,我们只优化重建损失。

局限与展望

局限

相信你也在图中看到了,VAE生成的图像比较模糊,这是因为在推导的过程中我们进行了一个放缩:

logpθ(X=x)=...=(pθ,qϕ)+DKL(qϕ,pθ)Let (pθ,qϕ)=logpθ(X=x)DKL(pθ,qϕ)(pθ,qϕ)\begin{aligned} \log p_{\theta}(X=x) &= ...\\ &= \ell(p_{\theta}, q_{\phi}) + D_{KL}(q_{\phi}, p_{\theta}) \quad \text{Let } \ell(p_{\theta}, q_{\phi})=\log p_\theta(X=x)-D_\text{KL}(p_{\theta}, q_{\phi}) \\ &\geq \ell(p_{\theta}, q_{\phi}) \end{aligned}

我们只是尽量最大化了这个下界,换句话说,我们只是确保了最糟糕的情况没有这么糟糕。因此,我们难以指望VAE生成的图片会有多好。

展望

虽然VAE生成的图像很糊,但实际上,VAE只是图像生成的一个起点。VAE将概率统计引入了深度学习,为后来的扩散模型提供了一个基础。扩散模型将“降噪算法”引入了潜空间,最终确保生成的图片的清晰度。

比如如今的腾讯混元图像生成模型,将扩散模型建立在VAE之上[1]

以下是使用混元图像生成模型生成的一些图片:

prompt: 夏日阳光下的四川大学江安校区,绿树成荫,校园建筑错落有致,红砖外墙与玻璃幕墙在阳光下熠熠生辉,远处可见明远湖波光粼粼,学生们在林荫道上漫步,画面充满生机与学术氛围。风格是写实油画,比例是4:3

yuanbao1 yuanbao2

换句话说,你看到的几乎所有AI生成的图像,几乎都是由扩散模型生成的,其实现的基础,正是VAE。

支持代码

VAE示例的实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class VAE(nn.Module):
def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
super(VAE, self).__init__()

self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)

self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
nn.Sigmoid()
)

def encode(self, x):
h = self.encoder(x)
return self.fc_mu(h), self.fc_logvar(h)

def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std

def decode(self, z):
return self.decoder(z)

def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar

def loss_function(recon_x, x, mu, logvar):
BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD

transform = transforms.Compose([
transforms.ToTensor()
])

train_dataset = torchvision.datasets.FashionMNIST(
root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()

if batch_idx % 100 == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')

print(f'====> Epoch: {epoch} Avg loss: {train_loss / len(train_loader.dataset):.4f}')

def plot_original_and_reconstructed(index):
data, _ = train_dataset[index]
data = data.to(device)
model.eval()
with torch.no_grad():
recon_data, _, _ = model(data.unsqueeze(0))
original_image = data.cpu().view(28, 28).numpy()
reconstructed_image = recon_data.cpu().view(28, 28).numpy()
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(original_image, cmap='gray')
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(reconstructed_image, cmap='gray')
axes[1].set_title('Reconstructed Image')
axes[1].axis('off')

plt.show()

def random_generate():
model.eval()
with torch.no_grad():
z = torch.randn(64, 20).to(device)
recon_data = model.decode(z)
fig, axes = plt.subplots(8, 8, figsize=(10, 10))
for i, ax in enumerate(axes.flatten()):
image = recon_data[i].cpu().view(28, 28).numpy()
ax.imshow(image, cmap='gray')
ax.axis('off')
plt.show()


def doit():
epochs = 10
for epoch in range(1, epochs + 1):
train(epoch)
def draw():
for i in range(4):
plot_original_and_reconstructed(i)

doit()

draw()

参考资料

  1. ​​Li et al. (2024). Hunyuan-DiT: A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding. arXiv preprint arXiv:2405.08748.​
  2. Kingma, Diederik P., and Max Welling. “Auto-Encoding Variational Bayes.” Statistics and Probability Letters, vol. 162, Sept. 2020, p. 108773. doi:10.1016/j.spl.2020.108773.
  3. Ramchandran, S., Tikhonov, G., Lönnroth, O., Tiikkainen, P., & Lähdesmäki, H. (2022). Learning Conditional Variational Autoencoders with Missing Covariates. arXiv preprint arXiv:2203.01218.
  4. https://zhuanlan.zhihu.com/p/348498294