27 从零生成式对抗网络 (GAN) 教程

27 从零生成式对抗网络 (GAN) 教程

本教程将教你如何使用生成式对抗网络(GAN)中的一种变种——深度卷积生成对抗网络(DCGAN),来生成手写数字(MNIST数据集)。

小节项目案例 1:使用 DCGAN 生成手写数字

1. 引言

生成式对抗网络(GAN)是一种能够生成新数据的神经网络架构。本小节重点介绍如何构建和训练一个简单的 DCGAN 来生成手写数字。这项任务将基于经典的 MNIST 数据集。

2. 环境准备

在开始之前,请确保你的环境中安装了以下库:

1
pip install tensorflow matplotlib numpy

3. 数据集准备

首先,我们需要加载 MNIST 数据集。可以使用 tf.keras.datasets 来加载:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# 加载 MNIST 数据集
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

# 归一化数据到 [-1, 1] 范围
x_train = x_train.astype('float32') / 255.0
x_train = (x_train - 0.5) * 2.0 # 转换到 [-1, 1]

# 扩展维度
x_train = np.expand_dims(x_train, axis=-1)

# 检查数据形状
print(x_train.shape) # (60000, 28, 28, 1)

4. 创建生成器

生成器的任务是将随机噪声 z 转换为逼真的图像。在此示例中,我们使用几个卷积层和反卷积层来构建生成器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def build_generator():
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.ReLU())
model.add(tf.keras.layers.Reshape((7, 7, 256)))
model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=1, padding='same', use_bias=False))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.ReLU())
model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=2, padding='same', use_bias=False))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.ReLU())
model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=2, padding='same', use_bias=False, activation='tanh'))

return model

generator = build_generator()
generator.summary()

5. 创建判别器

判别器的任务是判断输入图像是真实的还是生成的。我们也使用几个卷积层来构建判别器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def build_discriminator():
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=2, padding='same', input_shape=[28, 28, 1]))
model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=2, padding='same'))
model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(1))

return model

discriminator = build_discriminator()
discriminator.summary()

6. 损失函数和优化器

DCGAN 使用的损失函数是对抗损失。我们将使用 BinaryCrossentropy 作为损失函数,优化器使用 Adam。

1
2
3
4
5
6
7
8
9
10
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss

def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)

7. 训练循环

接下来,我们需要编写训练循环。在训练中,生成器和判别器会相互竞争。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# 训练参数
EPOCHS = 50
BATCH_SIZE = 256
NOISE_DIM = 100

# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(60000).batch(BATCH_SIZE)

# 优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# 训练步骤
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])

with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)

real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)

gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)

gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# 训练过程
def train(EPOCHS):
for epoch in range(EPOCHS):
for image_batch in train_dataset:
train_step(image_batch)

# 生成并保存图像
generate_and_save_images(generator, epoch)

# 生成并保存图像的函数
def generate_and_save_images(model, epoch):
noise = tf.random.normal([16, NOISE_DIM])
generated_images = model(noise, training=False)
generated_images = (generated_images + 1) / 2.0 # 还原到 [0, 1] 范围

plt.figure(figsize=(4, 4))
for i in range(generated_images.shape[0]):
plt.subplot(4, 4, i + 1)
plt.imshow(generated_images[i, :, :, 0], cmap='gray')
plt.axis('off')
plt.savefig(f'gan_generated_epoch_{epoch}.png')
plt.show()

# 开始训练
train(EPOCHS)

8. 结果与总结

运行以上代码之后,你会看到

27 从零生成式对抗网络 (GAN) 教程

https://zglg.work/gan-network-tutorial/27/

作者

AI教程网

发布于

2024-08-07

更新于

2024-08-10

许可协议