5 GAN的基本原理之损失函数的定义

在上一篇中,我们探讨了生成对抗网络(GAN)中生成器和判别器的角色。生成器的任务是生成尽可能真实的数据,而判别器则负责区分实际数据和生成数据的真假。在这一节中,我们将深入了解损失函数的定义,它是衡量生成器与判别器性能的核心。

损失函数的基本概念

在 GAN 中,损失函数用于优化生成器和判别器。我们需要定义损失函数,使两个网络相互竞争,从而提升生成器的生成能力和判别器的识别能力。

对抗损失函数

GAN 的核心思想是“对抗”。我们通过以下公式来定义对抗损失:

$$
\mathcal{L}(D, G) = \mathbb{E}{x \sim p{data}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]
$$

在这个公式中:

  • $D(x)$ 是判别器在真实数据 $x$ 上的输出。
  • $D(G(z))$ 是判别器在生成数据 $G(z)$ 上的输出。

这里,$D(x)$ 越接近 1,$D(G(z))$ 越接近 0,损失就越小,说明判别器能够很好地区分真实和生成的数据。

生成器的损失

生成器的目标是使判别器误以为生成的数据是真实的。因此,生成器的损失函数为:

$$
\mathcal{L}(G) = \mathbb{E}_{z \sim p_z(z)} [\log (D(G(z)))]
$$

在这个公式中,$G(z)$ 是生成器生成的数据。生成器的目标是最大化 $D(G(z))$,使判别器认为这些生成的数据是真实的。

最优解

在理论上,当 GAN 的训练达到平衡状态时,总损失函数 $ \mathcal{L}(D, G) $ 应该减少到 0:

  1. 判别器 $D$ 的输出来区分真实样本和生成样本都相等,即 $D(x) = 1/2$ 和 $D(G(z)) = 1/2$。
  2. 此时生成器能够生成非常逼真的样本,以至于判别器无法区分。

案例分析

请考虑一个简单的场景,我们使用 GAN 来生成手写数字图像(例如 MNIST 数据集)。在训练过程中,生成器试图生成手写数字图像,而判别器则试图区分真实的手写数字和生成的手写数字。

代码示例

以下是一个简单的 GAN 实现示例,演示如何定义损失函数并进行优化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
import torch.optim as optim

# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 28*28),
nn.Tanh()
)

def forward(self, z):
return self.model(z).view(-1, 1, 28, 28)

# 定义判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid()
)

def forward(self, img):
return self.model(img)

# 初始化网络
generator = Generator()
discriminator = Discriminator()

# 定义损失函数
criterion = nn.BCELoss()

# 定义优化器
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 假设 z 是从标准正态分布中随机采样的噪声
# 真实样本的标签是真实标签 1,生成样本的标签是假标签 0
z = torch.randn(64, 100)
real_samples = torch.randint(0, 2, (64, 1)).float() # 假设这是从真实数据集中提取的真实样本

# 判别器的损失
D_real = discriminator(real_samples)
D_fake = discriminator(generator(z))
loss_d = criterion(D_real, torch.ones_like(D_real)) + criterion(D_fake, torch.zeros_like(D_fake))

# 生成器的损失
loss_g = criterion(D_fake, torch.ones_like(D_fake)) # 生成器希望 D_fake 接近 1

# 更新判别器和生成器的参数
optimizer_d.zero_grad()
loss_d.backward()
optimizer_d.step()

optimizer_g.zero_grad()
loss_g.backward()
optimizer_g.step()

在这个例子中,我们定义了生成器和判别器的结构,并使用二元交叉熵损失(BCE)作为损失函数。通过如下动作,生成器和判别器可以在训练过程中不断优化。

总结

本节我们详细讨论了 GAN 中损失函数的定义。我们了解了生成器和判别器如何通过对抗性损失进行优化,从而不断提升生成数据的质量。损失函数是 GAN 训练的核心,通过精心设计的损失函数,我们可以实现理想的对抗训练。在下一节中,我们将探讨 GAN 的对抗训练流程,深入分析如何应用这些损失函数来实现有效的训练。

5 GAN的基本原理之损失函数的定义

https://zglg.work/gan-network-tutorial/5/

作者

IT教程网(郭震)

发布于

2024-08-10

更新于

2024-08-10

许可协议

分享转发

交流

更多教程加公众号

更多教程加公众号

加入星球获取PDF

加入星球获取PDF

打卡评论