31 从零到上手系统学习 PyTorch - 生成对抗网络 (GAN)

31 从零到上手系统学习 PyTorch - 生成对抗网络 (GAN)

1. 引言

生成对抗网络(GAN, Generative Adversarial Network)是由 Ian Goodfellow 等人在 2014 年提出的一种深度学习模型。GAN 由两部分组成:生成器(Generator)和判别器(Discriminator),它们通过对抗学习进行训练。生成器负责生成与真实数据相似的假数据,而判别器负责判断输入的样本是真实数据还是生成的数据。

2. GAN 的基本原理

GAN 的训练过程可以看作一个 零和博弈

  • 生成器:试图生成能够欺骗判别器的逼真数据。
  • 判别器:试图区分真实数据和生成的数据。

GAN 的目标是找到一个能够生成真实数据的生成器,通常用 minimax 问题来表述:

$$
\min_G \max_D V(D, G) = \mathbb{E}{x \sim p{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$

其中:

  • D(x) 是判别器对真实数据 x 的预测。
  • G(z) 是生成器生成的假数据,z 是从潜在空间随机采样的噪声。

3. GAN 的结构

3.1 生成器(Generator)

生成器通常是一个神经网络,它接受一个随机噪声向量作为输入,生成一个数据样本。例如,对于图像生成,生成器可能接收一个随机向量并输出一张图像。

3.2 判别器(Discriminator)

判别器也是一个神经网络,它接受一个数据样本(可以是真实数据或生成的数据)并输出一个概率值,表示该样本为真实的概率。

4. PyTorch 实现 GAN

下面我们将用 PyTorch 实现一个简单的 GAN,以生成手写数字图像(MNIST 数据集)。

4.1 环境准备

确保你已经安装了 PyTorch 和其他依赖库:

1
pip install torch torchvision matplotlib

4.2 数据准备

首先,我们需要从 torchvision 中加载 MNIST 数据集并进行预处理。

1
2
3
4
5
6
7
8
9
10
11
12
import torch
from torchvision import datasets, transforms

# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

4.3 生成器和判别器模型

接下来,我们定义生成器和判别器的模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch.nn as nn

# 生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 28 * 28), # MNIST 图像大小为 28x28
nn.Tanh() # 使用 Tanh 激活函数输出
)

def forward(self, z):
return self.model(z)

# 判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid() # 使用 Sigmoid 激活函数输出概率
)

def forward(self, x):
return self.model(x)

4.4 训练过程

我们将创建一个训练循环,使生成器和判别器交替训练。使用 BCELoss(二元交叉熵损失)作为损失函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch.optim as optim

# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()

# 定义优化器
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# 定义损失函数
criterion = nn.BCELoss()

# 开始训练
num_epochs = 50
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(dataloader):
batch_size = real_images.size(0)
real_images = real_images.view(batch_size, -1)

# 真实标签和虚假标签
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# 训练判别器
optimizer_D.zero_grad()
outputs = discriminator(real_images)
loss_D_real = criterion(outputs, real_labels)
loss_D_real.backward()

z = torch.randn(batch_size, 100)
fake_images = generator(z)
outputs = discriminator(fake_images.detach())
loss_D_fake = criterion(outputs, fake_labels)
loss_D_fake.backward()

optimizer_D.step()

# 训练生成器
optimizer_G.zero_grad()
outputs = discriminator(fake_images)
loss_G = criterion(outputs, real_labels)
loss_G.backward()

optimizer_G.step()

print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {loss_D_real.item() + loss_D_fake.item()}, g_loss: {loss_G.item()}')

# 你可以在这里保存生成的图像以观察生成进程

4.5 结果可视化

在训练过程中,我们可以定期查看生成器生成的图像,以观察生成质量是否提高。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib.pyplot as plt

# 生成并显示图像的函数
def show_images(generator, num_images=10):
z = torch.randn(num_images, 100)
generated_images = generator(z).view(-1, 1, 28, 28).detach()

plt.figure(figsize=(10, 10))
for i in range(num_images):
plt.subplot(1, num_images, i + 1)
plt.imshow(generated_images[i].squeeze(), cmap='gray')
plt.axis('off')
plt.show()

# 在训练的某个阶段调用 show_images()

5. 小结

GAN 是一种强大的生成模型,通过对抗训练,能够

31 从零到上手系统学习 PyTorch - 生成对抗网络 (GAN)

https://zglg.work/pytorch-tutorial/31/

作者

AI教程网

发布于

2024-08-07

更新于

2024-08-10

许可协议