15 GAN中的CNN结构详解
系列进度
AI 30 个神经网络 · 第 15 / 62 篇
GAN 是两个网络互相较劲:生成器负责骗过判别器,判别器负责找出破绽。真正的难点通常是训练稳定性。这篇重点看结构。先把数据流、关键模块和输出层画清楚,再回头看公式或代码。
我会同时看生成样本、判别器损失和样本多样性。只看 loss 容易误判 GAN 是否真的变好了。
在上一篇中,我们探讨了Faster R-CNN在目标检测中的应用案例。本篇将深入研究生成对抗网络(GAN)中的卷积神经网络(CNN)结构。理解这两者之间的关系及各自的功能,将有助于我们更好地掌握下一篇中将要讨论的GAN的实际应用实例。
GAN的基础概念
生成对抗网络(Generative Adversarial Networks,GAN)由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的任务是生成逼真的数据样本,而判别器则旨在区分真实样本和生成样本。
在大多数应用中,生成器和判别器都采用卷积神经网络(CNN)作为其基础结构。这是因为CNN擅长处理图像数据,非常适合用于图像生成与辨识任务。
CNN在GAN中的应用
1. 生成器的CNN结构
生成器通常使用反卷积(或转置卷积)来逐步将一个低维的随机噪声向量(通常是从正态分布中随机采样的矢量)转换为高维的图像。在这个过程中,生成器可能会包含如下层:
-
输入层:接收随机噪声向量,通常维度较小,例如:
z ~ N(0, 1),这个向量可能是100维的。 -
反卷积层:使用
Transpose Convolution(转置卷积)进行上采样,逐步增加特征图的大小,同时改变通道数。 -
激活函数:通常使用
ReLU函数,除了最后一层使用的tanh(为了将生成的图像标准化到[-1, 1])。 -
批量归一化:在每层中加入
Batch Normalization,以稳定训练过程,加速收敛。
这里是一个简单的生成器的构建示例代码:
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, z_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.ConvTranspose2d(z_dim, 128, 4, 1, 0, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, x):
return self.model(x)
2. 判别器的CNN结构
判别器的结构通常是一个标准的卷积网络,由下采样(卷积层 + 池化层)构成,用于提取特征并做出分类决策。其结构包括:
-
卷积层:使用标准卷积层来逐渐减少特征图的维度,同时增加通道数。
-
激活函数:通常使用
Leaky ReLU以减少在训练时出现“死亡神经元”的风险。 -
全连接层:最终将特征图展平,并通过全连接层输出一个标量,用来判断输入来源。
以下是判别器的示例代码:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
结合案例:图像生成
在实践中,我们可以利用GAN生成高质量的图像。例如,DCGAN(Deep Convolutional GAN)是一种流行的变体,采用了上述的CNN结构,专门用于图像生成。其通过训练生成器生成手写数字(MNIST数据集)或人脸(CelebA数据集)等图像。
理解 GAN 中的 CNN 结构时,先看生成器如何上采样,判别器如何提取特征,再比较卷积核、步幅和归一化。
具体地,训练过程一般包括如下步骤:
- 使用随机噪声输入生成器,生成图像。
- 将真实图像和生成的图像输入判别器,计算损失。
- 更新生成器和判别器的参数,优化其性能。
如果《GAN中的CNN结构详解》还没完全消化,可以从这张卡片的四个动作重新走一遍。
回看《GAN中的CNN结构详解》时,不必一次做大项目,先用一条简单样例确认主线是否清楚。
小结
在这一篇中,我们详细探讨了GAN中的CNN结构,涵盖生成器和判别器的设计理念以及具体的实现代码。理解这一基础知识对于深入后续的GAN应用实例是至关重要的。在下一篇中,我们将具体探讨GAN在图像转换、风格迁移等应用中的实例,希望你能对这一前沿技术有更深入的了解。
读《GAN中的CNN结构详解》时,可以把配图当成路线卡:先看整体顺序,再看每一步为什么这样做,最后再检查边界条件。
相关教程
相关入口
分享文章
转发到常用平台
微信/朋友圈可先复制链接
相关教程
从相近问题继续读
相关内容