6 生成对抗网络(GAN)基础概念

在上一篇文章中,我们介绍了深度学习的基本概念和其广泛应用。今天,我们将深入探讨生成对抗网络(Generative Adversarial Networks,简称GAN),这是一种近年来在生成模型领域引起广泛关注的方法。GAN的核心思想是通过对抗过程实现高质量的数据生成。

GAN的基本结构

生成对抗网络由两个主要组成部分构成:生成器(Generator)和判别器(Discriminator)。它们之间的关系可以看作是一场“对抗游戏”。下面是这两个组件的简要介绍:

  1. 生成器: 生成器是一个模型,它旨在生成逼真的样本。其输入通常是随机噪声(通常为服从均匀分布或正态分布的随机向量),输出是经过训练生成的图像或其他数据。

  2. 判别器: 判别器也同样是一个模型,它的任务是区分输入样本是真实的(来自真实数据集)还是由生成器生成的(伪造样本)。通过不断训练,判别器能够提升其辨别能力。

GAN的训练过程

GAN的训练过程可以概述为以下几个步骤:

  1. 初始化生成器和判别器: 同时设定两个模型的初始参数。

  2. 训练循环:

    • 训练判别器: 不断地使用真实样本和生成样本,训练判别器去辨别它们。对于真实样本,判别器的目标输出为1,而对于生成的样本,目标输出为0。

    • 训练生成器: 利用当前的判别器输出,调整生成器的权重。生成器希望生成样本能够“欺骗”判别器,使其判断为真实样本。其目标是最大化判别器对生成样本的输出。

  3. 对抗过程: 两个模型进行“竞争”,生成器试图改善生成样本的质量,而判别器则努力提高识别伪造样本的能力。

最终,经过多次迭代,生成器生成的样本越来越接近真实样本,从而实现成功的生成。

数学推导

在数学层面,GAN的优化目标是通过最小化下面的损失函数来实现的:

$$
\min_G \max_D \mathbb{E}{x \sim P{data}(x)}[\log D(x)] + \mathbb{E}{z \sim P{z}(z)}[\log(1 - D(G(z)))]
$$

其中:

  • $D(x)$ 表示判别器网络对真实样本 $x$ 的预测概率。
  • $D(G(z))$ 表示判别器对生成样本的预测概率。
  • $G(z)$ 为生成器输出,$z$ 是输入的随机噪声。

通过对这两个网络进行交替优化,可以不断改善生成效果和判别能力。

案例:MNIST手写数字生成

为了更好地理解GAN,我们来看看一个实际的示例:使用GAN生成MNIST手写数字。

数据准备

首先,我们需要导入必要的库并准备训练数据:

1
2
3
4
5
6
7
8
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist

# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train.astype(np.float32) / 255.0 # 归一化处理
X_train = np.expand_dims(X_train, axis=-1) # 扩展维度

构建生成器和判别器

接下来,我们构建生成器和判别器模型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, LeakyReLU

# 生成器
def build_generator():
model = Sequential()
model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(28 * 28 * 1, activation='tanh'))
model.add(Reshape((28, 28, 1)))
return model

# 判别器
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=(28, 28, 1)))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
return model

generator = build_generator()
discriminator = build_discriminator()

训练GAN

然后我们设置GAN进行训练:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from keras.optimizers import Adam

# 编译判别器
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

# 生成对抗网络
discriminator.trainable = False
gan_input = Sequential([generator, discriminator])
gan_input.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

# 训练过程
def train_gan(epochs, batch_size):
for epoch in range(epochs):
# 选择随机的真实图像
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[idx]

# 生成假图像
noise = np.random.normal(0, 1, (batch_size, 100))
fake_images = generator.predict(noise)

# 训练判别器
d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# 训练生成器
noise = np.random.normal(0, 1, (batch_size, 100))
g_loss = gan_input.train_on_batch(noise, np.ones((batch_size, 1)))

# 每1000次迭代打印损失情况
if epoch % 1000 == 0:
print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

train_gan(epochs=30000, batch_size=32)

在每次训练时,我们将实时生成的图像输出

6 生成对抗网络(GAN)基础概念

https://zglg.work/gen-ai-tutorial/6/

作者

IT教程网(郭震)

发布于

2024-08-10

更新于

2024-08-11

许可协议

分享转发

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论