1. 引言
生成对抗网络
(GAN, Generative Adversarial Network)是由 Ian Goodfellow
等人在 2014 年提出的一种深度学习模型。GAN 由两部分组成:生成器(Generator)和判别器(Discriminator),它们通过对抗学习进行训练。生成器负责生成与真实数据相似的假数据,而判别器负责判断输入的样本是真实数据还是生成的数据。
2. GAN 的基本原理
GAN 的训练过程可以看作一个 零和博弈
:
- 生成器:试图生成能够欺骗判别器的逼真数据。
- 判别器:试图区分真实数据和生成的数据。
GAN 的目标是找到一个能够生成真实数据的生成器,通常用 minimax
问题来表述:
$$
\min_G \max_D V(D, G) = \mathbb{E}{x \sim p{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$
其中:
D(x)
是判别器对真实数据 x
的预测。
G(z)
是生成器生成的假数据,z
是从潜在空间随机采样的噪声。
3. GAN 的结构
3.1 生成器(Generator)
生成器通常是一个神经网络,它接受一个随机噪声向量作为输入,生成一个数据样本。例如,对于图像生成,生成器可能接收一个随机向量并输出一张图像。
3.2 判别器(Discriminator)
判别器也是一个神经网络,它接受一个数据样本(可以是真实数据或生成的数据)并输出一个概率值,表示该样本为真实的概率。
4. PyTorch 实现 GAN
下面我们将用 PyTorch 实现一个简单的 GAN,以生成手写数字图像(MNIST 数据集)。
4.1 环境准备
确保你已经安装了 PyTorch 和其他依赖库:
1
| pip install torch torchvision matplotlib
|
4.2 数据准备
首先,我们需要从 torchvision
中加载 MNIST 数据集并进行预处理。
1 2 3 4 5 6 7 8 9 10 11 12
| import torch from torchvision import datasets, transforms
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform) dataloader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)
|
4.3 生成器和判别器模型
接下来,我们定义生成器和判别器的模型。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| import torch.nn as nn
class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 28 * 28), nn.Tanh() ) def forward(self, z): return self.model(z)
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2), nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, x): return self.model(x)
|
4.4 训练过程
我们将创建一个训练循环,使生成器和判别器交替训练。使用 BCELoss
(二元交叉熵损失)作为损失函数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| import torch.optim as optim
generator = Generator() discriminator = Discriminator()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002) optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
criterion = nn.BCELoss()
num_epochs = 50 for epoch in range(num_epochs): for i, (real_images, _) in enumerate(dataloader): batch_size = real_images.size(0) real_images = real_images.view(batch_size, -1)
real_labels = torch.ones(batch_size, 1) fake_labels = torch.zeros(batch_size, 1)
optimizer_D.zero_grad() outputs = discriminator(real_images) loss_D_real = criterion(outputs, real_labels) loss_D_real.backward()
z = torch.randn(batch_size, 100) fake_images = generator(z) outputs = discriminator(fake_images.detach()) loss_D_fake = criterion(outputs, fake_labels) loss_D_fake.backward()
optimizer_D.step()
optimizer_G.zero_grad() outputs = discriminator(fake_images) loss_G = criterion(outputs, real_labels) loss_G.backward()
optimizer_G.step()
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {loss_D_real.item() + loss_D_fake.item()}, g_loss: {loss_G.item()}')
|
4.5 结果可视化
在训练过程中,我们可以定期查看生成器生成的图像,以观察生成质量是否提高。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| import matplotlib.pyplot as plt
def show_images(generator, num_images=10): z = torch.randn(num_images, 100) generated_images = generator(z).view(-1, 1, 28, 28).detach() plt.figure(figsize=(10, 10)) for i in range(num_images): plt.subplot(1, num_images, i + 1) plt.imshow(generated_images[i].squeeze(), cmap='gray') plt.axis('off') plt.show()
|
5. 小结
GAN 是一种强大的生成模型,通过对抗训练,能够