1. 什么是 GAN?
生成对抗网络(GAN, Generative Adversarial Networks)是一种深度学习模型,由 Ian Goodfellow 在 2014 年提出。GAN 的主要目标是生成与真实样本相似的假样本。它由两个主要部分组成:生成器(Generator)和判别器(Discriminator)。
2. GAN 的组成部分
2.1 生成器(Generator)
生成器的任务是根据输入的随机噪声(通常是高维的随机向量)生成假样本。生成器试图捕捉数据的真实分布,使生成的样本看起来尽可能真实。
- 输入: 随机噪声向量
z
,通常从均匀分布或正态分布采样。 - 输出: 生成的假样本
G(z)
。
2.2 判别器(Discriminator)
判别器的任务是判断输入样本是来自真实数据集还是生成器生成的假样本。判别器给出的结果是一个概率值,表示样本为真实样本的概率。
- 输入: 实际样本
x
或生成的假样本G(z)
。 - 输出: 概率值
D(x)
,表示样本是真实的概率。
3. GAN 的训练过程
GAN 的训练过程可以看作是一个博弈(Game)过程,生成器和判别器相互竞争,试图提高自身的性能。
训练判别器: 根据真实样本和生成假样本,更新判别器的权重,以便它能够更好地区分真实和虚假的样本。训练目标是最大化以下目标函数:
1
D^* = \max_D \mathbb{E}_{x \sim P_{data}}[\log D(x)] + \mathbb{E}_{z \sim P_z}[\log (1 - D(G(z)))]
训练生成器: 更新生成器的权重,使得生成的假样本能够欺骗判别器,使其输出更高的概率值。训练目标是最小化以下目标函数:
1
G^* = \max_G \mathbb{E}_{z \sim P_z}[\log D(G(z))]
交替训练: GAN 的训练通常是先更新判别器,然后再更新生成器。这个过程会反复进行,直到生成器生成的样本与真实样本难以区分为止。
4. GAN 的损失函数
GAN 的损失函数是上述目标函数的具体表现。实现上通常使用二元交叉熵损失函数。
判别器损失:
1
D_loss = -torch.mean(torch.log(D(real_samples)) + torch.log(1 - D(G(noise))))
生成器损失:
1
G_loss = -torch.mean(torch.log(D(G(noise))))
5. GAN 的应用场景
- 图像生成:GAN 可以生成高质量的图像,例如人脸图像。
- 图像超分辨率:通过生成高分辨率的图像来提高低分辨率图像的质量。
- 图像到图像的转换:如图像风格迁移(例如,将白天的照片转换为夜晚)。
- 文本生成:生成自然语言文本(尽管效果不如 RNN 等模型)。
6. GAN 的示例代码
以下是一个简单的 GAN 训练框架示例,使用 PyTorch 库生成手写数字(MNIST 数据集)的样本:
1 | import torch |
7. 总结
生成对抗网络是当前深度学习领域非常活跃的研究方向之一。通过对抗训练,GAN 能够生成高质量的样本,因此在图像生成、风格转换、数据增强等领域有着广泛的应用。生成器和判别器