在上一篇中,我们讨论了改善 GAN 训练的模型架构变化,了解到不同架构设计在提升生成效果方面的重要性。今天,我们将深入探讨 GAN 的实际应用,特别是图像生成的案例。这一过程不仅体现了 GAN 的强大能力,同时也为我们实际应用 GAN 提供了宝贵的示例。
GAN简介
生成对抗网络(GAN)由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。生成器试图从随机噪声中生成逼真的图像,而判别器的任务是区分真实图像和生成的图像。通过对抗训练,生成器逐渐提高生成图像的质量。
在图像生成的过程中,我们通常会使用基于条件的 GAN(CGAN)或变分自编码器(VAE)等方法来赋予生成网络条件信息,以控制生成图像的特征。
图像生成案例
1. MNIST 手写数字生成
MNIST 数据集是一个经典的手写数字数据集,包含了 0 到 9 的手写数字。我们可以使用 GAN 来生成新的手写数字图像。
生成器与判别器架构
- 生成器:负责从随机噪声生成手写数字图像。输入为一维随机噪声向量,输出为一个 28x28 的图像。
- 判别器:负责区分输入图像是真实的 MNIST 图像还是生成的图像。
以下是一个简单的生成器和判别器的代码示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| import keras from keras.layers import Dense, Reshape, Flatten, Dropout, LeakyReLU from keras.models import Sequential import numpy as np
def build_generator(): model = Sequential() model.add(Dense(256, input_dim=100)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(28 * 28 * 1, activation='tanh')) model.add(Reshape((28, 28, 1))) return model
def build_discriminator(): model = Sequential() model.add(Flatten(input_shape=(28, 28, 1))) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.3)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dropout(0.3)) model.add(Dense(1, activation='sigmoid')) return model
|
2. 训练过程
GAN 的训练过程同时更新生成器和判别器。我们首先训练判别器,然后训练生成器。以下是训练代码的核心部分:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| discriminator = build_discriminator() discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
generator = build_generator()
z = keras.Input(shape=(100,)) img = generator(z) discriminator.trainable = False validity = discriminator(img)
gan = keras.Model(z, validity) gan.compile(loss='binary_crossentropy', optimizer='adam')
for epoch in range(num_epochs): idx = np.random.randint(0, X_train.shape[0], batch_size) real_imgs = X_train[idx]
z = np.random.normal(0, 1, (batch_size, 100)) fake_imgs = generator.predict(z)
d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1))) d_loss_fake = discriminator.train_on_batch(fake_imgs, np.zeros((batch_size, 1))) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
z = np.random.normal(0, 1, (batch_size, 100)) g_loss = gan.train_on_batch(z, np.ones((batch_size, 1)))
print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100 * d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
|
3. 生成图像展示
训练完成后,我们可以生成新的数字图像。每次输入不同的随机噪声,生成器都会输出相应的手写数字图像。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| import matplotlib.pyplot as plt
def generate_and_save_images(model, epoch, test_input): predictions = model.predict(test_input) plt.figure(figsize=(10, 10)) for i in range(25): plt.subplot(5, 5, i + 1) plt.imshow(predictions[i, :, :, 0], cmap='gray') plt.axis('off') plt.savefig(f"gan_generated_epoch_{epoch}.png") plt.show()
random_latent_vectors = np.random.normal(0, 1, (25, 100)) generate_and_save_images(generator, num_epochs, random_latent_vectors)
|
总结
在本章节中,我们通过实际案例展示了如何使用 GAN 进行图像生成。在下一篇中,我们将继续探索 GAN 的应用,特别是风格转移等技术,进一步扩展 GAN 在图像处理中的潜力。通过不断实践和学习,我们能够更好地掌握 GAN 的应用,推动计算机视觉领域的进步。