Jupyter AI

6 生成对抗网络(GAN)基础概念

📅 发表日期: 2024年8月10日

分类: 🧠生成式 AI 教程

👁️阅读: --

在上一篇文章中,我们介绍了深度学习的基本概念和其广泛应用。今天,我们将深入探讨生成对抗网络(Generative Adversarial Networks,简称GAN),这是一种近年来在生成模型领域引起广泛关注的方法。GAN的核心思想是通过对抗过程实现高质量的数据生成。

GAN的基本结构

生成对抗网络由两个主要组成部分构成:生成器(Generator)和判别器(Discriminator)。它们之间的关系可以看作是一场“对抗游戏”。下面是这两个组件的简要介绍:

  1. 生成器: 生成器是一个模型,它旨在生成逼真的样本。其输入通常是随机噪声(通常为服从均匀分布或正态分布的随机向量),输出是经过训练生成的图像或其他数据。

  2. 判别器: 判别器也同样是一个模型,它的任务是区分输入样本是真实的(来自真实数据集)还是由生成器生成的(伪造样本)。通过不断训练,判别器能够提升其辨别能力。

GAN的训练过程

GAN的训练过程可以概述为以下几个步骤:

  1. 初始化生成器和判别器: 同时设定两个模型的初始参数。

  2. 训练循环:

    • 训练判别器: 不断地使用真实样本和生成样本,训练判别器去辨别它们。对于真实样本,判别器的目标输出为1,而对于生成的样本,目标输出为0。

    • 训练生成器: 利用当前的判别器输出,调整生成器的权重。生成器希望生成样本能够“欺骗”判别器,使其判断为真实样本。其目标是最大化判别器对生成样本的输出。

  3. 对抗过程: 两个模型进行“竞争”,生成器试图改善生成样本的质量,而判别器则努力提高识别伪造样本的能力。

最终,经过多次迭代,生成器生成的样本越来越接近真实样本,从而实现成功的生成。

数学推导

在数学层面,GAN的优化目标是通过最小化下面的损失函数来实现的:

minGmaxDExPdata(x)[logD(x)]+EzPz(z)[log(1D(G(z)))]\min_G \max_D \mathbb{E}_{x \sim P_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim P_{z}(z)}[\log(1 - D(G(z)))]

其中:

  • D(x)D(x) 表示判别器网络对真实样本 xx 的预测概率。
  • D(G(z))D(G(z)) 表示判别器对生成样本的预测概率。
  • G(z)G(z) 为生成器输出,zz 是输入的随机噪声。

通过对这两个网络进行交替优化,可以不断改善生成效果和判别能力。

案例:MNIST手写数字生成

为了更好地理解GAN,我们来看看一个实际的示例:使用GAN生成MNIST手写数字。

数据准备

首先,我们需要导入必要的库并准备训练数据:

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist

# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train.astype(np.float32) / 255.0  # 归一化处理
X_train = np.expand_dims(X_train, axis=-1)  # 扩展维度

构建生成器和判别器

接下来,我们构建生成器和判别器模型:

from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, LeakyReLU

# 生成器
def build_generator():
    model = Sequential()
    model.add(Dense(256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model

# 判别器
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    return model

generator = build_generator()
discriminator = build_discriminator()

训练GAN

然后我们设置GAN进行训练:

from keras.optimizers import Adam

# 编译判别器
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

# 生成对抗网络
discriminator.trainable = False
gan_input = Sequential([generator, discriminator])
gan_input.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

# 训练过程
def train_gan(epochs, batch_size):
    for epoch in range(epochs):
        # 选择随机的真实图像
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_images = X_train[idx]

        # 生成假图像
        noise = np.random.normal(0, 1, (batch_size, 100))
        fake_images = generator.predict(noise)

        # 训练判别器
        d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan_input.train_on_batch(noise, np.ones((batch_size, 1)))

        # 每1000次迭代打印损失情况
        if epoch % 1000 == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

train_gan(epochs=30000, batch_size=32)

在每次训练时,我们将实时生成的图像输出