9 条件GAN的训练和评估

在之前的文章中,我们探讨了条件生成对抗网络(cGAN)的应用实例。为了更深入地了解cGAN的工作原理,本篇将着重讨论其训练和评估方法。在深度学习的实践中,训练过程的设计和评估标准的选择直接影响模型的质量和应用效果。因此,我们将详细分析如何有效训练cGAN以及如何评估其生成结果。

1. 条件GAN的训练

1.1 训练过程

cGAN的训练过程与传统GAN类似,但我们在生成器和判别器中引入了条件信息。下面,我们将以MNIST手写数字生成的示例来说明cGAN的训练步骤。

  1. 准备数据集
    首先,我们需要加载MNIST数据集,并将其转换为可以供模型使用的格式。我们将每个图像与其对应的标签相结合,以使得生成器能够根据标签生成特定的数字。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    from keras.datasets import mnist
    import numpy as np

    # 加载数据
    (X_train, y_train), (_, _) = mnist.load_data()
    X_train = X_train.astype('float32') / 255.0
    y_train = y_train.astype('float32')

    # 将数据集扩展为(样本,宽度,高度,通道)
    X_train = np.expand_dims(X_train, axis=-1)
  2. 构建生成器和判别器
    cGAN的生成器和判别器的构建需同时接收条件信息。例如,生成器将随机噪声和标签作为输入,判别器将图像和标签作为输入。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    from keras.layers import Input, Dense, Reshape, Concatenate
    from keras.models import Model

    def build_generator():
    noise = Input(shape=(100,))
    label = Input(shape=(10,))
    model_input = Concatenate()([noise, label])
    x = Dense(128)(model_input)
    x = Reshape((4, 4, 8))(x)
    return Model([noise, label], x)

    def build_discriminator():
    img = Input(shape=(28, 28, 1))
    label = Input(shape=(10,))
    model_input = Concatenate()([img, label])
    x = Dense(128)(model_input)
    return Model([img, label], x)
  3. 定义损失和优化器
    在cGAN中,损失函数通常使用二元交叉熵(binary crossentropy)。同时,将生成器和判别器编译为可优化的模型。

    1
    2
    3
    4
    5
    from keras.optimizers import Adam

    generator = build_generator()
    discriminator = build_discriminator()
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam())
  4. 训练循环
    cGAN的训练循环包括以下步骤:

    • 随机选择一个标签;
    • 生成随机噪声;
    • 将噪声和标签输入生成器,生成伪样本;
    • 真实样本与伪样本一起喂入判别器进行训练。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    for epoch in range(num_epochs):
    for _ in range(batch_count):
    # 随机选择一个标签
    random_indices = np.random.randint(0, X_train.shape[0], batch_size)
    real_images = X_train[random_indices]
    labels = y_train[random_indices]

    # 生成随机噪声
    noise = np.random.normal(0, 1, (batch_size, 100))
    generated_images = generator.predict([noise, labels])

    # 生成标签one-hot编码
    real_labels = np.zeros((batch_size, 1))
    fake_labels = np.ones((batch_size, 1))
    d_loss_real = discriminator.train_on_batch([real_images, labels], real_labels)
    d_loss_fake = discriminator.train_on_batch([generated_images, labels], fake_labels)

    # 训练生成器
    noise = np.random.normal(0, 1, (batch_size, 100))
    valid_labels = np.ones((batch_size, 1))
    g_loss = combined_model.train_on_batch([noise, labels], valid_labels)

1.2 训练中的技巧

  • Label Smoothing:通过降低真实标签的值来增强判别器的稳定性。
  • 样本平衡:确保从每个类中均匀选取样本,以减少数据偏差。
  • 动态学习率:根据训练阶段动态调整学习率,优化训练效果。

2. 条件GAN的评估

评估生成模型的性能具有挑战性,特别是当生成数据与真实数据的质量和多样性都需要被考虑时。以下是几种评估方法:

2.1 可视化生成效果

最直接的方法是通过可视化生成的图像来评估其质量。在MNIST例子中,可以随机生成几个样本并展示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import matplotlib.pyplot as plt

# 随机生成一些样本
noise = np.random.normal(0, 1, (10, 100))
labels = np.array([i for i in range(10)]).reshape(-1, 1)
labels = np.random.randint(0, 10, size=(10, 10)) # Random one-hot labels

generated_images = generator.predict([noise, labels])
plt.figure(figsize=(10, 10))
for i in range(10):
plt.subplot(5, 10, i + 1)
plt.imshow(generated_images[i].reshape(28, 28), cmap='gray')
plt.axis('off')
plt.show()

2.2 FID和IS指标

Fréchet Inception Distance (FID)Inception Score (IS)是评估生成模型性能的常用指标。FID越低,表示生成样本与真实样本的相似度越高。IS则评估生成图像的多样性和质量。

实现FID的Python代码示例:

from scipy.linalg import sqrtm

def calculate_fid(real_images, generated_images):
    # 假设real_images和generated_images的形状都为(num_samples, 28, 28, 1)
    mu1, sigma1 = calculate_statistics(real_images)
    mu2, sigma2 = calculate_statistics(generated_images)
    fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    return fid_value

def calculate_statistics(images):
    # 计算均值和协方差矩阵
    mu = np.mean(images, axis=0)
    sigma = np.cov(images, rowvar=False)
    return mu, sigma

def calculate_frechet_distance(mu1, sigma

9 条件GAN的训练和评估

https://zglg.work/gans-advanced-one/9/

作者

IT教程网(郭震)

发布于

2024-08-15

更新于

2024-08-16

许可协议

分享转发

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论