郭震 AI公众号:郭震AI

6 GAN的基本原理:对抗训练的流程

发布日期:

最近更新:

分类: GAN网络从零教程

预计阅读: 4 分钟

阅读次数: 0

预计阅读4 分钟
结构重点4 个
图文要点6 张
正文规模1.6k 字
GAN的基本原理:对抗训练的流程结构图查看大图
GAN的基本原理:对抗训练的流程结构图

GAN 的关键是生成器和判别器互相推动,学习时要同时看结构、训练和样本质量。阅读时可以按「对抗训练的基本概念 -> 对抗训练的流程 -> 案例说明 -> 生成器」建立结构,再回到正文里的代码、案例或指标做验证。

GAN的基本原理:对抗训练的流程核对图查看大图
GAN的基本原理:对抗训练的流程核对图

读完后,用一个真实小任务复查:输入是什么,处理环节在哪里,输出是否可验收;失败时先查「对抗训练的基本概念」,再查「对抗训练的流程」。

在上一篇中,我们探讨了GAN的损失函数的定义,了解了它们是如何影响模型的训练过程的。在本篇中,我们将深入分析GAN的核心机制——对抗训练的流程。通过对抗训练,生成器和判别器在不断的互动中提升各自的能力,从而实现生成真实感极强的数据。让我们详细展开这一过程,并结合案例进行说明。

对抗训练的基本概念

生成对抗网络(GAN)由两个神经网络构成:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗的方式进行训练,形成一个博弈过程。在这个过程中,生成器负责生成假样本,而判别器则负责判断样本是真实样本还是生成样本。

对抗训练的目标是在于平衡这两个网络的能力,生成器希望能够生成越来越真实的数据;而判别器则希望能够识别出这些假样本。这个过程是动态的,而且是迭代进行的,随着训练的进行,这两个网络会不断调整自己的策略。

对抗训练的流程

对抗训练的基本流程如下:

  1. 初始化网络

    1. 初始化生成器G和判别器D的网络参数。
  2. 真实样本选择

    1. 从真实数据集中随机选择一批真实样本,记为 x_real
  3. 生成假样本

    1. 通过生成器G生成一批假样本,输入为噪声向量 z,即 x_fake = G(z)
  4. 训练判别器

    1. 判别器D接受真实样本和生成的假样本。
    2. 计算判别器对真实样本的预测概率 D(x_real) 和对假样本的预测概率 D(x_fake)
    3. 计算判别器损失:LD=ExPdata[logD(x)]ExPg[log(1D(x))]L_D = - \mathbb{E}_{x \sim P_{data}}[\log D(x)] - \mathbb{E}_{x \sim P_{g}}[\log (1 - D(x))]
    4. 通过反向传播更新判别器D的参数,以最小化损失LDL_D
  • 训练生成器

    1. 生成器G生成一批新的假样本 x_fake = G(z)
    2. 计算生成器的损失:LG=ExPg[logD(G(z))]L_G = - \mathbb{E}_{x \sim P_{g}}[\log D(G(z))]
    3. 通过反向传播更新生成器G的参数,以最小化损失LGL_G
  • 循环迭代

    1. 重复第2步到第5步,直到达到预设的训练轮次或生成样本的质量达到要求。
  • 案例说明

    为了便于理解上述流程,我们来看一个具体的代码示例。这里以TensorFlow/Keras为基础实现一个简单的GAN模型。

    GAN对抗训练流程判断卡查看大图
    GAN对抗训练流程判断卡

    理解对抗训练流程时,先看生成器产出样本、判别器给出反馈、参数如何更新。循环稳定才可能生成好结果。

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import layers
    
    # 生成器
    def build_generator(latent_dim):
        model = tf.keras.Sequential()
        model.add(layers.Dense(128, activation='relu', input_dim=latent_dim))
        model.add(layers.Dense(784, activation='sigmoid'))  # 输出28*28的扁平化图像
        return model
    
    # 判别器
    def build_discriminator(input_shape):
        model = tf.keras.Sequential()
        model.add(layers.Dense(128, activation='relu', input_shape=input_shape))
        model.add(layers.Dense(1, activation='sigmoid'))  # 二元分类
        return model
    
    # GAN模型
    latent_dim = 100
    generator = build_generator(latent_dim)
    discriminator = build_discriminator((784,))
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    # GAN组件
    discriminator.trainable = False
    gan_input = layers.Input(shape=(latent_dim,))
    fake_img = generator(gan_input)
    gan_output = discriminator(fake_img)
    gan_model = tf.keras.Model(gan_input, gan_output)
    gan_model.compile(loss='binary_crossentropy', optimizer='adam')
    
    # 训练过程
    def train_gan(epochs, batch_size):
        for epoch in range(epochs):
            # 生成假样本
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            generated_images = generator.predict(noise)
    
            # 真实样本样本
            real_images = ...  # 从真实数据集中加载
    
            # 标签
            real_labels = np.ones(batch_size)
            fake_labels = np.zeros(batch_size)
    
            # 训练判别器
            d_loss_real = discriminator.train_on_batch(real_images, real_labels)
            d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    
            # 训练生成器
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            g_loss = gan_model.train_on_batch(noise, real_labels)  # 使用真实的标签来欺骗判别器
    
            if epoch % 1000 == 0:
                print(f'Epoch {epoch}, D Loss: {d_loss[0]}, G Loss: {g_loss}')
    
    # 设置训练参数并开始训练
    train_gan(epochs=10000, batch_size=32)
    

    在本示例中,我们构建了一个简单的生成器和判别器,并展示了基于对抗训练的基本训练流程。生成器负责从随机噪声生成图像,判别器则评估这些图像的真实性,并更新生成器的输出效果。

    GAN的基本原理:对抗训练的流程应用复盘卡查看大图
    GAN的基本原理:对抗训练的流程应用复盘卡

    学完《GAN的基本原理:对抗训练的流程》后,不妨换一个自己的场景试一次,重点观察输入、处理和输出是否能对应起来。

    GAN的基本原理:对抗训练的流程应用检查卡查看大图
    GAN的基本原理:对抗训练的流程应用检查卡

    如果想把《GAN的基本原理:对抗训练的流程》用到自己的任务里,可以先缩小场景,只验证一个最关键的判断点。

    小结

    在本篇中,我们详细阐述了对抗训练的流程,并结合代码示例使得这些概念更加具体化。对抗训练是GAN的核心,通过生成器和判别器的博弈,让模型逐渐学会生成更为真实的样本。在下一篇中,我们将一起设置环境和依赖,以便开始构建我们的第一个GAN模型。 生成对抗网络阅读地图卡

    读完《GAN的基本原理:对抗训练的流程》不要只停在“看懂了”。回头挑一个步骤动手做一遍,再记录哪里卡住,后面的学习会更稳。

    相关教程

    相关入口

    AI 教程总索引

    分享文章

    转发到常用平台

    微信/朋友圈可先复制链接

    相关教程

    AI 教程总索引

    相关内容

    相关 AI 教程

    返回栏目

    Reader Messages

    读者留言

    有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

    最多 800 字

    为了防刷,每条留言会做长度、链接数量和提交频率限制。

    0/800

    留言列表

    0
    正在加载留言...