This tutorial shows how to use a popular variant of the generative adversarial network (GAN), the deep convolutional generative adversarial network (DCGAN), to generate handwritten digits from the MNIST dataset.
Section Project Case 1: Generating Handwritten Digits with a DCGAN
1. Introduction
A generative adversarial network (GAN) is a neural-network architecture that can generate new data. This section focuses on building and training a simple DCGAN to generate handwritten digits, using the classic MNIST dataset.
2. Environment Setup
Before starting, make sure the following libraries are installed in your environment:
```bash
pip install tensorflow matplotlib numpy
```
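As a small sanity check (not part of the original steps), you can confirm that a TensorFlow 2.x installation with the built-in Keras API is active:

```python
import tensorflow as tf

# This tutorial assumes TensorFlow 2.x, which ships tf.keras.
print(tf.__version__)
```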
3. Preparing the Dataset
First, we need to load the MNIST dataset, which can be done with tf.keras.datasets:
```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Load the MNIST training images; the labels are not needed for a GAN.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

# Scale pixel values to [-1, 1] to match the generator's tanh output.
x_train = x_train.astype('float32') / 255.0
x_train = (x_train - 0.5) * 2.0

# Add a channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1).
x_train = np.expand_dims(x_train, axis=-1)

print(x_train.shape)
```
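As an optional check that is not part of the original code, you can display a few preprocessed training digits to confirm the scaling and the added channel dimension:

```python
# Pixel values are now in [-1, 1], so map them back to [0, 1] for plotting.
plt.figure(figsize=(4, 4))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow((x_train[i, :, :, 0] + 1) / 2.0, cmap='gray')
    plt.axis('off')
plt.show()
```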
4. Building the Generator
The generator's task is to transform a random noise vector z into a realistic image. In this example, it is built from a dense layer followed by several transposed-convolution (deconvolution) layers that progressively upsample to 28×28.
```python
def build_generator():
    model = tf.keras.Sequential()

    # Project the 100-dimensional noise vector to a 7x7x256 feature map.
    model.add(tf.keras.layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())
    model.add(tf.keras.layers.Reshape((7, 7, 256)))

    # 7x7x256 -> 7x7x128
    model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=1, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    # 7x7x128 -> 14x14x64
    model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=2, padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())

    # 14x14x64 -> 28x28x1; tanh keeps outputs in [-1, 1], matching the data scaling.
    model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=2, padding='same', use_bias=False, activation='tanh'))
    return model

generator = build_generator()
generator.summary()
```
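Before training, you can feed a single noise vector through the untrained generator as a quick sanity check of the output shape. This snippet is an optional addition, not part of the original tutorial:

```python
sample_noise = tf.random.normal([1, 100])
sample_image = generator(sample_noise, training=False)
print(sample_image.shape)  # expected: (1, 28, 28, 1)

# The untrained generator produces noise-like output.
plt.imshow(sample_image[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.show()
```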
5. Building the Discriminator
The discriminator's task is to decide whether an input image is real or generated. It is likewise built from a few convolutional layers.
```python
def build_discriminator():
    model = tf.keras.Sequential()

    # 28x28x1 -> 14x14x64
    model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=2, padding='same', input_shape=[28, 28, 1]))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Dropout(0.3))

    # 14x14x64 -> 7x7x128
    model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=2, padding='same'))
    model.add(tf.keras.layers.LeakyReLU(alpha=0.2))
    model.add(tf.keras.layers.Dropout(0.3))

    # A single logit: positive means "real", negative means "fake".
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1))

    return model

discriminator = build_discriminator()
discriminator.summary()
```
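As another optional check not in the original walkthrough, you can score one untrained-generator image with the untrained discriminator; the output is a raw logit, which is why the loss below uses `from_logits=True`:

```python
sample_noise = tf.random.normal([1, 100])
sample_image = generator(sample_noise, training=False)

# A single raw logit; no sigmoid is applied inside the model.
decision = discriminator(sample_image, training=False)
print(decision)
```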
6. Loss Functions and Optimizers
A DCGAN is trained with an adversarial loss. We use BinaryCrossentropy as the loss function and Adam as the optimizer, one for each network.
```python
# from_logits=True because the discriminator outputs raw logits, not probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    # Real images should be classified as 1, generated images as 0.
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    # The generator wants its images to be classified as real (1).
    return cross_entropy(tf.ones_like(fake_output), fake_output)
```
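To make the two objectives concrete, here is a small illustrative check on hand-picked logits; the numbers are assumptions chosen only for demonstration:

```python
# Strongly "real" logits for real images and strongly "fake" logits for
# generated ones: the discriminator loss is near 0, while the generator
# loss is large because the discriminator is not fooled.
real_logits = tf.constant([[5.0], [4.0]])
fake_logits = tf.constant([[-5.0], [-4.0]])
print(discriminator_loss(real_logits, fake_logits).numpy())  # close to 0
print(generator_loss(fake_logits).numpy())                   # roughly 4.5
```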
7. Training Loop
Next, we write the training loop. During training, the generator and the discriminator compete against each other.
```python
EPOCHS = 50
BATCH_SIZE = 256
NOISE_DIM = 100

# Shuffle the full training set and split it into batches.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(60000).batch(BATCH_SIZE)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Each network is updated with its own gradients and its own optimizer.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def generate_and_save_images(model, epoch):
    noise = tf.random.normal([16, NOISE_DIM])
    generated_images = model(noise, training=False)
    # Map the tanh output from [-1, 1] back to [0, 1] for display.
    generated_images = (generated_images + 1) / 2.0

    plt.figure(figsize=(4, 4))
    for i in range(generated_images.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(generated_images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.savefig(f'gan_generated_epoch_{epoch}.png')
    plt.show()

def train(epochs):
    for epoch in range(epochs):
        for image_batch in train_dataset:
            train_step(image_batch)

        # Save a grid of sample images at the end of every epoch.
        generate_and_save_images(generator, epoch)

train(EPOCHS)
```
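If you want to be able to resume training, a minimal checkpointing sketch is shown below. It is not part of the original tutorial, and the directory name `./dcgan_ckpt` is an arbitrary assumption:

```python
# Track both networks and both optimizers so training can be restored later.
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)
checkpoint.save(file_prefix='./dcgan_ckpt/ckpt')
```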
8. Results and Summary
After running the code above, a 4×4 grid of generated digits is saved at the end of every epoch (gan_generated_epoch_*.png). The earliest grids look like random noise, and as training progresses the samples gradually come to resemble real handwritten digits. This completes a minimal DCGAN: a convolutional generator and discriminator trained adversarially with binary cross-entropy and Adam.