郭震 AI公众号:郭震AI

15 GAN中的CNN结构详解

发布日期:

最近更新:

分类: 30个神经网络

预计阅读: 4 分钟

阅读次数: 0

系列进度

AI 30 个神经网络 · 第 15 / 62

预计阅读4 分钟
结构重点6 个
图文要点6 张
正文规模1.6k 字
GAN中的CNN结构详解结构图查看大图
GAN中的CNN结构详解结构图

GAN 是两个网络互相较劲:生成器负责骗过判别器,判别器负责找出破绽。真正的难点通常是训练稳定性。这篇重点看结构。先把数据流、关键模块和输出层画清楚,再回头看公式或代码。

GAN中的CNN结构详解实操核对图查看大图
GAN中的CNN结构详解实操核对图

我会同时看生成样本、判别器损失和样本多样性。只看 loss 容易误判 GAN 是否真的变好了。

在上一篇中,我们探讨了Faster R-CNN在目标检测中的应用案例。本篇将深入研究生成对抗网络(GAN)中的卷积神经网络(CNN)结构。理解这两者之间的关系及各自的功能,将有助于我们更好地掌握下一篇中将要讨论的GAN的实际应用实例。

GAN的基础概念

生成对抗网络(Generative Adversarial Networks,GAN)由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的任务是生成逼真的数据样本,而判别器则旨在区分真实样本和生成样本。

在大多数应用中,生成器和判别器都采用卷积神经网络(CNN)作为其基础结构。这是因为CNN擅长处理图像数据,非常适合用于图像生成与辨识任务。

CNN在GAN中的应用

1. 生成器的CNN结构

生成器通常使用反卷积(或转置卷积)来逐步将一个低维的随机噪声向量(通常是从正态分布中随机采样的矢量)转换为高维的图像。在这个过程中,生成器可能会包含如下层:

  • 输入层:接收随机噪声向量,通常维度较小,例如:z ~ N(0, 1),这个向量可能是100维的。

  • 反卷积层:使用Transpose Convolution(转置卷积)进行上采样,逐步增加特征图的大小,同时改变通道数。

  • 激活函数:通常使用ReLU函数,除了最后一层使用的tanh(为了将生成的图像标准化到[-1, 1])。

  • 批量归一化:在每层中加入Batch Normalization,以稳定训练过程,加速收敛。

这里是一个简单的生成器的构建示例代码:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 128, 4, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

2. 判别器的CNN结构

判别器的结构通常是一个标准的卷积网络,由下采样(卷积层 + 池化层)构成,用于提取特征并做出分类决策。其结构包括:

  • 卷积层:使用标准卷积层来逐渐减少特征图的维度,同时增加通道数。

  • 激活函数:通常使用Leaky ReLU以减少在训练时出现“死亡神经元”的风险。

  • 全连接层:最终将特征图展平,并通过全连接层输出一个标量,用来判断输入来源。

以下是判别器的示例代码:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

结合案例:图像生成

在实践中,我们可以利用GAN生成高质量的图像。例如,DCGAN(Deep Convolutional GAN)是一种流行的变体,采用了上述的CNN结构,专门用于图像生成。其通过训练生成器生成手写数字(MNIST数据集)或人脸(CelebA数据集)等图像。

GAN中的CNN结构判断卡查看大图
GAN中的CNN结构判断卡

理解 GAN 中的 CNN 结构时,先看生成器如何上采样,判别器如何提取特征,再比较卷积核、步幅和归一化。

具体地,训练过程一般包括如下步骤:

  1. 使用随机噪声输入生成器,生成图像。
  2. 将真实图像和生成的图像输入判别器,计算损失。
  3. 更新生成器和判别器的参数,优化其性能。
GAN中的CNN结构详解应用复盘卡查看大图
GAN中的CNN结构详解应用复盘卡

如果《GAN中的CNN结构详解》还没完全消化,可以从这张卡片的四个动作重新走一遍。

GAN中的CNN结构详解应用检查卡查看大图
GAN中的CNN结构详解应用检查卡

回看《GAN中的CNN结构详解》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。

小结

在这一篇中,我们详细探讨了GAN中的CNN结构,涵盖生成器和判别器的设计理念以及具体的实现代码。理解这一基础知识对于深入后续的GAN应用实例是至关重要的。在下一篇中,我们将具体探讨GAN在图像转换、风格迁移等应用中的实例,希望你能对这一前沿技术有更深入的了解。 神经网络阅读地图卡

读《GAN中的CNN结构详解》时,可以把配图当成路线卡:先看整体顺序,再看每一步为什么这样做,最后再检查边界条件。

相关教程

相关入口

AI 教程总索引

分享文章

转发到常用平台

微信/朋友圈可先复制链接

相关教程

AI 教程总索引

相关内容

相关 AI 教程

返回栏目

Reader Messages

读者留言

有问题、补充资料或实测结果,可以直接留下。这里不需要登录。

最多 800 字

为了防刷,每条留言会做长度、链接数量和提交频率限制。

0/800

留言列表

0
正在加载留言...