郭震 AI公众号:郭震AI

7 生成式AI基础概念:变分自编码器(VAE)

发布日期:

最近更新:

分类: 生成式AI从零教程

预计阅读: 4 分钟

阅读次数: 0

预计阅读4 分钟
结构重点8 个
图文要点6 张
正文规模1.6k 字
变分自编码器概念图查看大图
变分自编码器概念图

VAE 的价值在于把数据压到可采样的潜在空间里。理解潜在空间,就能理解它如何生成新样本。

变分自编码器核对图查看大图
变分自编码器核对图

我会检查重构质量和潜在空间是否平滑,两个指标一起看才完整。

在生成式AI系列教程中,我们从生成对抗网络(GAN)开始,探讨了它的基本原理和应用。在本篇中,我们将深入研究变分自编码器(VAE),一种重要的生成模型,它在数据生成和特征学习中发挥着关键作用。

什么是变分自编码器(VAE)

变分自编码器(VAE)是一种深度学习模型,它通过学习数据的潜在表示,来生成与输入数据相似的新样本。与传统的自编码器不同,VAE通过变分推断来进行隐变量建模,旨在最大化数据的似然估计。

VAE基础概念判断卡查看大图
VAE基础概念判断卡

理解 VAE 时,先看输入如何被编码成潜变量分布,再看采样和解码如何生成新样本。

VAE的核心思想是将输入的数据压缩成一个潜在空间(latent space),并从这个潜在空间中重新生成数据。这种生成过程可以通过数学公式进行描述:

p(xz)=N(x;μ(z),σ2(z))p(x | z) = \mathcal{N}(x; \mu(z), \sigma^2(z))

这里,xx表示输入数据,zz表示潜在变量,μ(z)\mu(z)σ2(z)\sigma^2(z)分别是基于潜在变量的生成输出的均值和方差。

VAE的组成

VAE由以下几个部分组成:

  1. 编码器(Encoder):将输入数据xx映射到潜在空间中,输出潜在变量的均值μ\mu和方差σ2\sigma^2

    q(zx)=N(z;μ(x),σ2(x))q(z | x) = \mathcal{N}(z; \mu(x), \sigma^2(x))
  2. 重参数化技巧:为了能够进行反向传播,VAE使用重参数化技巧,将随机采样的过程转换为确定性函数的组合:

z=μ(x)+σ(x)ϵ其中 ϵN(0,I)z = \mu(x) + \sigma(x) \cdot \epsilon \quad \text{其中 } \epsilon \sim \mathcal{N}(0, I)
  • 解码器(Decoder):将潜在变量zz映射回数据空间,以生成新的样本:

    p(xz)=N(x;μ(z),σ2(z))p(x | z) = \mathcal{N}(x; \mu(z), \sigma^2(z))
  • 损失函数:VAE的损失函数由两部分组成:

    • 重构损失(Reconstruction Loss):衡量生成样本与真实样本的相似度。
    • KL 散度(Kullback-Leibler Divergence):衡量潜在分布与先验分布的差异。

    最终的损失函数为:

    L(x)=Eq(zx)[logp(xz)]DKL(q(zx)p(z))\mathcal{L}(x) = \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) || p(z))
  • VAE案例:手写数字生成

    让我们通过一个具体的案例来理解VAE的工作原理。我们将使用PyTorch库实现一个VAE,用于生成手写数字(MNIST数据集)。

    生成式 AI阅读地图卡查看大图
    生成式 AI阅读地图卡

    读《生成式AI基础概念:变分自编码器(VAE)》时,先确定要解决的场景,再把关键概念和练习动作串起来。这样读到细节时,不容易只记住零散名词。

    数据准备

    首先,确保安装好PyTorch和相关库。然后,我们可以加载MNIST数据集:

    import torch
    from torchvision import datasets, transforms
    
    # 数据变换
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    # 下载MNIST数据集
    train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=32, shuffle=True)
    

    VAE模型实现

    接下来,我们定义VAE模型:

    import torch.nn as nn
    import torch.nn.functional as F
    
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
            self.fc1 = nn.Linear(28*28, 400)
            self.fc21 = nn.Linear(400, 20)  # 均值
            self.fc22 = nn.Linear(400, 20)  # 方差
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 28*28)
    
        def encode(self, x):
            h1 = F.relu(self.fc1(x))
            return self.fc21(h1), self.fc22(h1)  # 均值和方差
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
    
        def decode(self, z):
            h3 = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(h3))
    
        def forward(self, x):
            mu, logvar = self.encode(x.view(-1, 28*28))
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar
    

    训练模型

    训练VAE模型是优化损失函数的过程,这里我们使用Adam优化器:

    def loss_function(recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 28*28), reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD
    
    model = VAE()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    model.train()
    for epoch in range(10):
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(torch.device("cpu"))
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            optimizer.step()
    

    生成新样本

    训练完成后,可以用VAE生成新的手写数字样本:

    model.eval()
    with torch.no_grad():
        z = torch.randn(64, 20)  # 从潜在空间中采样
        sample = model.decode(z).cpu()
    

    生成的样本可以通过Matplotlib可视化:

    import matplotlib.pyplot as plt
    
    sample = sample.view(-1, 28, 28)
    
    def show_samples(samples):
        fig, axes = plt.subplots(8, 8, figsize=(10, 10))
        for i, ax in enumerate(axes.flat):
            ax.imshow(samples[i].numpy(), cmap='gray')
            ax.axis('off')
        plt.show()
    
    show_samples(sample)
    
    生成式AI基础概念:变分自编码器(VAE)应用复盘卡查看大图
    生成式AI基础概念:变分自编码器(VAE)应用复盘卡

    如果《生成式AI基础概念:变分自编码器(VAE)》还没完全消化,可以从这张卡片的四个动作重新走一遍。

    生成式AI基础概念:变分自编码器(VAE)应用检查卡查看大图
    生成式AI基础概念:变分自编码器(VAE)应用检查卡

    回看《生成式AI基础概念:变分自编码器(VAE)》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。

    小结

    在本篇教程中,我们详细介绍了变分自编码器(VAE)的基本概念、工作原理及实现方法。通过手写数字生成的案例,我们发现VAE不仅

    相关教程

    相关入口

    AI 教程总索引

    分享文章

    转发到常用平台

    微信/朋友圈可先复制链接

    相关教程

    AI 教程总索引

    相关内容

    相关 AI 教程

    返回栏目

    Reader Messages

    读者留言

    有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

    最多 800 字

    为了防刷,每条留言会做长度、链接数量和提交频率限制。

    0/800

    留言列表

    0
    正在加载留言...