3 生成对抗网络基础回顾之GAN的损失函数

在上一篇文章中,我们讨论了生成对抗网络(GAN)的基本架构,包括其主要组成部分——生成器(Generator)和判别器(Discriminator)。在本篇文章中,我们将深入探讨GAN的损失函数。损失函数是GAN训练过程中的核心组成部分,它直接影响到模型的学习效果和生成样本的质量。接下来,我们将回顾GAN的损失函数的基本概念、各种损失函数的变体以及它们对模型性能的影响。

GAN的基本损失函数

在原始GAN的框架中,生成器和判别器通过一个对抗过程进行训练。其目标是生成器尽可能生成真实的样本,而判别器则努力区分真实样本与生成样本。其损失函数可以用如下公式表示:

$$
\text{min}G \text{max}D V(D, G) = \mathbb{E}{x \sim p{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z))]
$$

  • 其中,$D(x)$ 表示判别器对真实样本的输出,$G(z)$ 表示生成器生成的样本,$z$ 是从随机噪声分布中采样的。
  • $\mathbb{E}{x \sim p{\text{data}}}$ 表示对真实数据分布的期望,而 $\mathbb{E}_{z \sim p_z}$ 则是对生成器输入噪声分布的期望。

生成器和判别器的损失

在训练过程中,判别器需要最大化其损失函数,来更好地识别真实样本与生成样本。而生成器则需要最小化其损失函数,以生成更具“真实性”的样本。

损失函数解读

  • 当判别器的性能较差时,即 $D(x)$ 及 $D(G(z))$ 的输出较低,生成器的损失将会较低,因为生成器生成的样本能够欺骗判别器。
  • 当判别器的性能较好时,生成器生成的样本将无法被判别器所接受,从而导致其损失增加。

这就形成了一个动态对抗的过程。

损失函数的变体

随着GAN的发展,研究者们提出了多种损失函数变体,以解决原始GAN在训练过程中的不稳定性。如:

  1. **Wasserstein GAN (WGAN)**:
    WGAN通过使用Wasserstein距离替代JS散度,极大改善了模型训练的稳定性。其损失函数为:

    $$
    L_D = \mathbb{E}{x \sim p{\text{data}}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]
    $$

    生成器的目标是最大化而非最小化,从而引导训练过程更平滑。

  2. **Least Squares GAN (LSGAN)**:
    LSGAN引入了最小二乘损失,使得生成器和判别器的输出更加接近于真实值。其损失函数为:

    $$
    L_D = \frac{1}{2} \mathbb{E}{x \sim p{\text{data}}}[(D(x) - 1)^2] + \frac{1}{2} \mathbb{E}_{z \sim p_z}[(D(G(z)))^2]
    $$

    这样做的好处在于,判别器的输出可以通过回归来学习,减小梯度消失的问题。

  3. Hinge Loss GAN
    在一些场景中,Hinge损失也被广泛使用,尤其是在图像生成任务中。其损失函数为:

    $$
    L_D = \mathbb{E}{x \sim p{\text{data}}}[\max(0, 1 - D(x))] + \mathbb{E}_{z \sim p_z}[\max(0, 1 + D(G(z)))]
    $$

    Hinge损失的形式使得模型更加鲁棒,特别是在样本不平衡的情况下。

实例代码

下面是一个使用PyTorch实现WGAN的简单示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn

class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 784),
nn.Tanh(),
)

def forward(self, z):
return self.model(z)

class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, 1),
)

def forward(self, x):
return self.model(x)

# 假设我们已实现WGAN的训练循环

小结

在本篇文章中,我们回顾了生成对抗网络的损失函数,并探讨了不同变体的特点及其带来的优势。理解GAN的损失函数对于提升生成模型的性能至关重要。在下一篇文章中,我们将继续探讨有关GAN的训练技巧,特别是如何在训练过程中实现稳定性和有效性。

通过这些内容,你将对生成对抗网络有更深入的理解,为后面的模块打下坚实的基础。

3 生成对抗网络基础回顾之GAN的损失函数

https://zglg.work/gans-advanced-one/3/

作者

IT教程网(郭震)

发布于

2024-08-15

更新于

2024-08-16

许可协议

分享转发

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论