16 生成对抗网络(GAN)图像生成案例探索

在上一篇中,我们讨论了改善 GAN 训练的模型架构变化,了解到不同架构设计在提升生成效果方面的重要性。今天,我们将深入探讨 GAN 的实际应用,特别是图像生成的案例。这一过程不仅体现了 GAN 的强大能力,同时也为我们实际应用 GAN 提供了宝贵的示例。

GAN简介

生成对抗网络(GAN)由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。生成器试图从随机噪声中生成逼真的图像,而判别器的任务是区分真实图像和生成的图像。通过对抗训练,生成器逐渐提高生成图像的质量。

在图像生成的过程中,我们通常会使用基于条件的 GAN(CGAN)或变分自编码器(VAE)等方法来赋予生成网络条件信息,以控制生成图像的特征。

图像生成案例

1. MNIST 手写数字生成

MNIST 数据集是一个经典的手写数字数据集,包含了 0 到 9 的手写数字。我们可以使用 GAN 来生成新的手写数字图像。

生成器与判别器架构

  • 生成器:负责从随机噪声生成手写数字图像。输入为一维随机噪声向量,输出为一个 28x28 的图像。
  • 判别器:负责区分输入图像是真实的 MNIST 图像还是生成的图像。

以下是一个简单的生成器和判别器的代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import keras
from keras.layers import Dense, Reshape, Flatten, Dropout, LeakyReLU
from keras.models import Sequential
import numpy as np

# 生成器
def build_generator():
model = Sequential()
model.add(Dense(256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(28 * 28 * 1, activation='tanh'))
model.add(Reshape((28, 28, 1)))
return model

# 判别器
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=(28, 28, 1)))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.3))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.3))
model.add(Dense(1, activation='sigmoid'))
return model

2. 训练过程

GAN 的训练过程同时更新生成器和判别器。我们首先训练判别器,然后训练生成器。以下是训练代码的核心部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# 编译模型
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

generator = build_generator()

# GAN模型
z = keras.Input(shape=(100,))
img = generator(z)
discriminator.trainable = False
validity = discriminator(img)

gan = keras.Model(z, validity)
gan.compile(loss='binary_crossentropy', optimizer='adam')

# 训练过程
for epoch in range(num_epochs):
# 训练判别器
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_imgs = X_train[idx]

z = np.random.normal(0, 1, (batch_size, 100)) # 随机噪声
fake_imgs = generator.predict(z)

d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
d_loss_fake = discriminator.train_on_batch(fake_imgs, np.zeros((batch_size, 1)))
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# 训练生成器
z = np.random.normal(0, 1, (batch_size, 100))
g_loss = gan.train_on_batch(z, np.ones((batch_size, 1)))

# 输出损失
print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100 * d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

3. 生成图像展示

训练完成后,我们可以生成新的数字图像。每次输入不同的随机噪声,生成器都会输出相应的手写数字图像。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib.pyplot as plt

def generate_and_save_images(model, epoch, test_input):
predictions = model.predict(test_input)
plt.figure(figsize=(10, 10))
for i in range(25):
plt.subplot(5, 5, i + 1)
plt.imshow(predictions[i, :, :, 0], cmap='gray')
plt.axis('off')
plt.savefig(f"gan_generated_epoch_{epoch}.png")
plt.show()

# 生成图像
random_latent_vectors = np.random.normal(0, 1, (25, 100))
generate_and_save_images(generator, num_epochs, random_latent_vectors)

总结

在本章节中,我们通过实际案例展示了如何使用 GAN 进行图像生成。在下一篇中,我们将继续探索 GAN 的应用,特别是风格转移等技术,进一步扩展 GAN 在图像处理中的潜力。通过不断实践和学习,我们能够更好地掌握 GAN 的应用,推动计算机视觉领域的进步。

16 生成对抗网络(GAN)图像生成案例探索

https://zglg.work/gan-network-tutorial/16/

作者

IT教程网(郭震)

发布于

2024-08-10

更新于

2024-08-10

许可协议

分享转发

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论