在上一篇文章中,我们介绍了深度学习的基本概念和其广泛应用。今天,我们将深入探讨生成对抗网络(Generative Adversarial Networks,简称GAN),这是一种近年来在生成模型领域引起广泛关注的方法。GAN的核心思想是通过对抗过程实现高质量的数据生成。
GAN的基本结构
生成对抗网络由两个主要组成部分构成:生成器(Generator)和判别器(Discriminator)。它们之间的关系可以看作是一场“对抗游戏”。下面是这两个组件的简要介绍:
生成器: 生成器
是一个模型,它旨在生成逼真的样本。其输入通常是随机噪声(通常为服从均匀分布或正态分布的随机向量),输出是经过训练生成的图像或其他数据。
判别器: 判别器
也同样是一个模型,它的任务是区分输入样本是真实的(来自真实数据集)还是由生成器生成的(伪造样本)。通过不断训练,判别器能够提升其辨别能力。
GAN的训练过程
GAN的训练过程可以概述为以下几个步骤:
初始化生成器和判别器: 同时设定两个模型的初始参数。
训练循环:
对抗过程: 两个模型进行“竞争”,生成器试图改善生成样本的质量,而判别器则努力提高识别伪造样本的能力。
最终,经过多次迭代,生成器生成的样本越来越接近真实样本,从而实现成功的生成。
数学推导
在数学层面,GAN的优化目标是通过最小化下面的损失函数来实现的:
$$
\min_G \max_D \mathbb{E}{x \sim P{data}(x)}[\log D(x)] + \mathbb{E}{z \sim P{z}(z)}[\log(1 - D(G(z)))]
$$
其中:
- $D(x)$ 表示判别器网络对真实样本 $x$ 的预测概率。
- $D(G(z))$ 表示判别器对生成样本的预测概率。
- $G(z)$ 为生成器输出,$z$ 是输入的随机噪声。
通过对这两个网络进行交替优化,可以不断改善生成效果和判别能力。
案例:MNIST手写数字生成
为了更好地理解GAN,我们来看看一个实际的示例:使用GAN生成MNIST手写数字。
数据准备
首先,我们需要导入必要的库并准备训练数据:
1 2 3 4 5 6 7 8
| import numpy as np import matplotlib.pyplot as plt from keras.datasets import mnist
(X_train, _), (_, _) = mnist.load_data() X_train = X_train.astype(np.float32) / 255.0 X_train = np.expand_dims(X_train, axis=-1)
|
构建生成器和判别器
接下来,我们构建生成器和判别器模型:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| from keras.models import Sequential from keras.layers import Dense, Reshape, Flatten, LeakyReLU
def build_generator(): model = Sequential() model.add(Dense(256, input_dim=100)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(28 * 28 * 1, activation='tanh')) model.add(Reshape((28, 28, 1))) return model
def build_discriminator(): model = Sequential() model.add(Flatten(input_shape=(28, 28, 1))) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) return model
generator = build_generator() discriminator = build_discriminator()
|
训练GAN
然后我们设置GAN进行训练:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| from keras.optimizers import Adam
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
discriminator.trainable = False gan_input = Sequential([generator, discriminator]) gan_input.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
def train_gan(epochs, batch_size): for epoch in range(epochs): idx = np.random.randint(0, X_train.shape[0], batch_size) real_images = X_train[idx]
noise = np.random.normal(0, 1, (batch_size, 100)) fake_images = generator.predict(noise)
d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1))) d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1))) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
noise = np.random.normal(0, 1, (batch_size, 100)) g_loss = gan_input.train_on_batch(noise, np.ones((batch_size, 1)))
if epoch % 1000 == 0: print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
train_gan(epochs=30000, batch_size=32)
|
在每次训练时,我们将实时生成的图像输出