6 生成对抗网络(GAN)基础概念
在上一篇文章中,我们介绍了深度学习的基本概念和其广泛应用。今天,我们将深入探讨生成对抗网络(Generative Adversarial Networks,简称GAN),这是一种近年来在生成模型领域引起广泛关注的方法。GAN的核心思想是通过对抗过程实现高质量的数据生成。
GAN的基本结构
生成对抗网络由两个主要组成部分构成:生成器(Generator)和判别器(Discriminator)。它们之间的关系可以看作是一场“对抗游戏”。下面是这两个组件的简要介绍:
-
生成器:
生成器
是一个模型,它旨在生成逼真的样本。其输入通常是随机噪声(通常为服从均匀分布或正态分布的随机向量),输出是经过训练生成的图像或其他数据。 -
判别器:
判别器
也同样是一个模型,它的任务是区分输入样本是真实的(来自真实数据集)还是由生成器生成的(伪造样本)。通过不断训练,判别器能够提升其辨别能力。
GAN的训练过程
GAN的训练过程可以概述为以下几个步骤:
-
初始化生成器和判别器: 同时设定两个模型的初始参数。
-
训练循环:
-
训练判别器: 不断地使用真实样本和生成样本,训练判别器去辨别它们。对于真实样本,判别器的目标输出为1,而对于生成的样本,目标输出为0。
-
训练生成器: 利用当前的判别器输出,调整生成器的权重。生成器希望生成样本能够“欺骗”判别器,使其判断为真实样本。其目标是最大化判别器对生成样本的输出。
-
-
对抗过程: 两个模型进行“竞争”,生成器试图改善生成样本的质量,而判别器则努力提高识别伪造样本的能力。
最终,经过多次迭代,生成器生成的样本越来越接近真实样本,从而实现成功的生成。
数学推导
在数学层面,GAN的优化目标是通过最小化下面的损失函数来实现的:
其中:
- 表示判别器网络对真实样本 的预测概率。
- 表示判别器对生成样本的预测概率。
- 为生成器输出, 是输入的随机噪声。
通过对这两个网络进行交替优化,可以不断改善生成效果和判别能力。
案例:MNIST手写数字生成
为了更好地理解GAN,我们来看看一个实际的示例:使用GAN生成MNIST手写数字。
数据准备
首先,我们需要导入必要的库并准备训练数据:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train.astype(np.float32) / 255.0 # 归一化处理
X_train = np.expand_dims(X_train, axis=-1) # 扩展维度
构建生成器和判别器
接下来,我们构建生成器和判别器模型:
from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, LeakyReLU
# 生成器
def build_generator():
model = Sequential()
model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(28 * 28 * 1, activation='tanh'))
model.add(Reshape((28, 28, 1)))
return model
# 判别器
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=(28, 28, 1)))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
return model
generator = build_generator()
discriminator = build_discriminator()
训练GAN
然后我们设置GAN进行训练:
from keras.optimizers import Adam
# 编译判别器
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
# 生成对抗网络
discriminator.trainable = False
gan_input = Sequential([generator, discriminator])
gan_input.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
# 训练过程
def train_gan(epochs, batch_size):
for epoch in range(epochs):
# 选择随机的真实图像
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[idx]
# 生成假图像
noise = np.random.normal(0, 1, (batch_size, 100))
fake_images = generator.predict(noise)
# 训练判别器
d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# 训练生成器
noise = np.random.normal(0, 1, (batch_size, 100))
g_loss = gan_input.train_on_batch(noise, np.ones((batch_size, 1)))
# 每1000次迭代打印损失情况
if epoch % 1000 == 0:
print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
train_gan(epochs=30000, batch_size=32)
在每次训练时,我们将实时生成的图像输出