在生成式AI系列教程中,我们从生成对抗网络(GAN)开始,探讨了它的基本原理和应用。在本篇中,我们将深入研究变分自编码器(VAE),一种重要的生成模型,它在数据生成和特征学习中发挥着关键作用。
什么是变分自编码器(VAE)
变分自编码器(VAE)是一种深度学习模型,它通过学习数据的潜在表示,来生成与输入数据相似的新样本。与传统的自编码器不同,VAE通过变分推断来进行隐变量建模,旨在最大化数据的似然估计。
VAE的核心思想是将输入的数据压缩成一个潜在空间(latent space),并从这个潜在空间中重新生成数据。这种生成过程可以通过数学公式进行描述:
$$
p(x | z) = \mathcal{N}(x; \mu(z), \sigma^2(z))
$$
这里,$x$表示输入数据,$z$表示潜在变量,$\mu(z)$和$\sigma^2(z)$分别是基于潜在变量的生成输出的均值和方差。
VAE的组成
VAE由以下几个部分组成:
编码器(Encoder):将输入数据$x$映射到潜在空间中,输出潜在变量的均值$\mu$和方差$\sigma^2$。
$$
q(z | x) = \mathcal{N}(z; \mu(x), \sigma^2(x))
$$
重参数化技巧:为了能够进行反向传播,VAE使用重参数化技巧,将随机采样的过程转换为确定性函数的组合:
$$
z = \mu(x) + \sigma(x) \cdot \epsilon \quad \text{其中 } \epsilon \sim \mathcal{N}(0, I)
$$
解码器(Decoder):将潜在变量$z$映射回数据空间,以生成新的样本:
$$
p(x | z) = \mathcal{N}(x; \mu(z), \sigma^2(z))
$$
损失函数:VAE的损失函数由两部分组成:
- 重构损失(Reconstruction Loss):衡量生成样本与真实样本的相似度。
- KL 散度(Kullback-Leibler Divergence):衡量潜在分布与先验分布的差异。
最终的损失函数为:
$$
\mathcal{L}(x) = \mathbb{E}{q(z|x)}[\log p(x|z)] - D{KL}(q(z|x) || p(z))
$$
VAE案例:手写数字生成
让我们通过一个具体的案例来理解VAE的工作原理。我们将使用PyTorch库实现一个VAE,用于生成手写数字(MNIST数据集)。
数据准备
首先,确保安装好PyTorch和相关库。然后,我们可以加载MNIST数据集:
1 2 3 4 5 6 7 8
| import torch from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
|
VAE模型实现
接下来,我们定义VAE模型:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| import torch.nn as nn import torch.nn.functional as F
class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(28*28, 400) self.fc21 = nn.Linear(400, 20) self.fc22 = nn.Linear(400, 20) self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 28*28)
def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1)
def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return mu + eps * std
def decode(self, z): h3 = F.relu(self.fc3(z)) return torch.sigmoid(self.fc4(h3))
def forward(self, x): mu, logvar = self.encode(x.view(-1, 28*28)) z = self.reparameterize(mu, logvar) return self.decode(z), mu, logvar
|
训练模型
训练VAE模型是优化损失函数的过程,这里我们使用Adam优化器:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| def loss_function(recon_x, x, mu, logvar): BCE = F.binary_cross_entropy(recon_x, x.view(-1, 28*28), reduction='sum') KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return BCE + KLD
model = VAE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train() for epoch in range(10): for batch_idx, (data, _) in enumerate(train_loader): data = data.to(torch.device("cpu")) optimizer.zero_grad() recon_batch, mu, logvar = model(data) loss = loss_function(recon_batch, data, mu, logvar) loss.backward() optimizer.step()
|
生成新样本
训练完成后,可以用VAE生成新的手写数字样本:
1 2 3 4
| model.eval() with torch.no_grad(): z = torch.randn(64, 20) sample = model.decode(z).cpu()
|
生成的样本可以通过Matplotlib可视化:
1 2 3 4 5 6 7 8 9 10 11 12
| import matplotlib.pyplot as plt
sample = sample.view(-1, 28, 28)
def show_samples(samples): fig, axes = plt.subplots(8, 8, figsize=(10, 10)) for i, ax in enumerate(axes.flat): ax.imshow(samples[i].numpy(), cmap='gray') ax.axis('off') plt.show()
show_samples(sample)
|
小结
在本篇教程中,我们详细介绍了变分自编码器(VAE)的基本概念、工作原理及实现方法。通过手写数字生成的案例,我们发现VAE不仅