9 条件GAN的训练和评估
在之前的文章中,我们探讨了条件生成对抗网络(cGAN)的应用实例。为了更深入地了解cGAN的工作原理,本篇将着重讨论其训练和评估方法。在深度学习的实践中,训练过程的设计和评估标准的选择直接影响模型的质量和应用效果。因此,我们将详细分析如何有效训练cGAN以及如何评估其生成结果。
1. 条件GAN的训练
1.1 训练过程
cGAN的训练过程与传统GAN类似,但我们在生成器和判别器中引入了条件信息。下面,我们将以MNIST手写数字生成的示例来说明cGAN的训练步骤。
准备数据集:
首先,我们需要加载MNIST数据集,并将其转换为可以供模型使用的格式。我们将每个图像与其对应的标签相结合,以使得生成器能够根据标签生成特定的数字。构建生成器和判别器:
cGAN的生成器和判别器的构建需同时接收条件信息。例如,生成器将随机噪声和标签作为输入,判别器将图像和标签作为输入。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17from keras.layers import Input, Dense, Reshape, Concatenate
from keras.models import Model
def build_generator():
noise = Input(shape=(100,))
label = Input(shape=(10,))
model_input = Concatenate()([noise, label])
x = Dense(128)(model_input)
x = Reshape((4, 4, 8))(x)
return Model([noise, label], x)
def build_discriminator():
img = Input(shape=(28, 28, 1))
label = Input(shape=(10,))
model_input = Concatenate()([img, label])
x = Dense(128)(model_input)
return Model([img, label], x)定义损失和优化器:
在cGAN中,损失函数通常使用二元交叉熵(binary crossentropy)。同时,将生成器和判别器编译为可优化的模型。训练循环:
cGAN的训练循环包括以下步骤:- 随机选择一个标签;
- 生成随机噪声;
- 将噪声和标签输入生成器,生成伪样本;
- 真实样本与伪样本一起喂入判别器进行训练。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21for epoch in range(num_epochs):
for _ in range(batch_count):
# 随机选择一个标签
random_indices = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[random_indices]
labels = y_train[random_indices]
# 生成随机噪声
noise = np.random.normal(0, 1, (batch_size, 100))
generated_images = generator.predict([noise, labels])
# 生成标签one-hot编码
real_labels = np.zeros((batch_size, 1))
fake_labels = np.ones((batch_size, 1))
d_loss_real = discriminator.train_on_batch([real_images, labels], real_labels)
d_loss_fake = discriminator.train_on_batch([generated_images, labels], fake_labels)
# 训练生成器
noise = np.random.normal(0, 1, (batch_size, 100))
valid_labels = np.ones((batch_size, 1))
g_loss = combined_model.train_on_batch([noise, labels], valid_labels)
1.2 训练中的技巧
- Label Smoothing:通过降低真实标签的值来增强判别器的稳定性。
- 样本平衡:确保从每个类中均匀选取样本,以减少数据偏差。
- 动态学习率:根据训练阶段动态调整学习率,优化训练效果。
2. 条件GAN的评估
评估生成模型的性能具有挑战性,特别是当生成数据与真实数据的质量和多样性都需要被考虑时。以下是几种评估方法:
2.1 可视化生成效果
最直接的方法是通过可视化生成的图像来评估其质量。在MNIST例子中,可以随机生成几个样本并展示:
1 | import matplotlib.pyplot as plt |
2.2 FID和IS指标
Fréchet Inception Distance (FID)
和Inception Score (IS)
是评估生成模型性能的常用指标。FID越低,表示生成样本与真实样本的相似度越高。IS则评估生成图像的多样性和质量。
实现FID的Python代码示例:
from scipy.linalg import sqrtm
def calculate_fid(real_images, generated_images):
# 假设real_images和generated_images的形状都为(num_samples, 28, 28, 1)
mu1, sigma1 = calculate_statistics(real_images)
mu2, sigma2 = calculate_statistics(generated_images)
fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
return fid_value
def calculate_statistics(images):
# 计算均值和协方差矩阵
mu = np.mean(images, axis=0)
sigma = np.cov(images, rowvar=False)
return mu, sigma
def calculate_frechet_distance(mu1, sigma
9 条件GAN的训练和评估