生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习框架,用于生成与真实数据具有相似分布的新数据。这一模型由 Ian Goodfellow 等人于 2014 年提出,已经在图像生成、图像修复、超分辨率等多个领域取得了显著成效。以下是对经典 GAN 模型的详细介绍。
1. GAN 的基本概念
在 GAN 中,有两个主要的组件:生成器(Generator)和判别器(Discriminator)。
以下是一个简单的 GAN 实现示例,使用 TensorFlow 和 Keras 库来生成手写数字(如 MNIST 数据集)。
4.1 导入库
1 2 3 4 5 6
import numpy as np import matplotlib.pyplot as plt from keras.datasets import mnist from keras.models import Sequential from keras.layers import Dense, LeakyReLU from keras.optimizers import Adam
4.2 构建生成器和判别器
1 2 3 4 5 6 7 8 9 10 11 12 13
defbuild_generator(latent_dim): model = Sequential() model.add(Dense(128, input_dim=latent_dim)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(784, activation='tanh')) return model
defbuild_discriminator(input_shape): model = Sequential() model.add(Dense(128, input_shape=input_shape)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) return model
GAN 是一个强大的模型,能够生成与真实数据分布相似的新数据。通过生成器与判别器的对抗训练,GAN 可以学习到复杂的数据分布。尽管 GAN 具有很强的生成能力,仍然面临着一些挑战,例如模式崩溃(mode collapse)和不稳定的训练过程。因此,针对 GAN 的研究仍在不断深入,许多变种和改进算法也被提出以克服这些问题。
# 训练 DCGAN for epoch inrange(num_epochs): for i, (real_images, _) inenumerate(data_loader): batch_size = real_images.size(0) # 将真实图像放入 CUDA real_images = real_images.cuda()