10 WGAN (Wasserstein GAN) 详细教程

10 WGAN (Wasserstein GAN) 详细教程

1. 引言

生成对抗网络(GAN)是一种强大的生成模型,但在训练过程中常常会遭遇不稳定性和模式崩溃等问题。为了解决这些问题,Wetzerstein GAN(WGAN)有效地改进了GAN的训练方式,通过引入Wasserstein距离来度量真实数据和生成数据之间的差异。WGAN具有更稳定的训练过程和更好的生成效果。

2. WGAN 基础概念

2.1 Wasserstein 距离

  • Wasserstein 距离,又称地球移动者距离(Earth Mover’s Distance, EMD),用于量化两个概率分布之间的差异。
  • 与传统的 Jensen-Shannon Divergence 相比,Wasserstein 距离对生成分布的支持情况更加敏感,这使得在训练时模型能够更好地捕捉数据分布的变化。

2.2 作图

在训练过程中,WGAN追求的是:

$$
W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [| x - y |]
$$

其中,$\Pi(P_r, P_g)$ 是所有将真实分布 $P_r$ 和生成分布 $P_g$ 结合的联合分布的集合。

3. WGAN 的模型结构

WGAN 的结构与标准 GAN 类似,包括一个生成器(Generator, G)和一个判别器(Discriminator, D)。但是,WGAN 的判别器不再是一个概率输出模型,而是一个 Lipschitz 连续函数。为了满足这一条件,通常使用一个权重裁剪或梯度惩罚的方法。

3.1 生成器(G)

  • 输入:随机噪声 $\mathbf{z}$,通常从一个均匀分布或正态分布中采样。
  • 输出:生成样本 $\mathbf{G(z)}$。

3.2 判别器(D)

  • 输入:数据样本(真实样本或生成样本)。
  • 输出:一个实数值,用于反映样本的真实性。

4. WGAN 训练过程

4.1 损失函数

WGAN 的损失函数定义如下:

  • 对于生成器:

$$
L_G = -\mathbb{E}[D(G(z))]
$$

  • 对于判别器:

$$
L_D = \mathbb{E}[D(x)] - \mathbb{E}[D(G(z))] + \lambda \cdot \mathbb{E}[(|\nabla D(x) |_2 - 1)^2]
$$

其中,$\lambda$ 是梯度惩罚的超参数。

4.2 梯度惩罚

为了满足 Lipschitz 条件,而不是简单地裁剪权重,可以通过以下方式来加以限制:

  1. 从真实样本和生成样本中间随机采样,生成新的样本 $\hat{x} = \epsilon \cdot x + (1 - \epsilon) \cdot G(z)$,其中 $\epsilon$ 从均匀分布 $U(0, 1)$ 中采样。
  2. 计算判别器对 $\hat{x}$ 输入的梯度。
  3. 添加梯度惩罚项。

5. WGAN实例代码

以下是一个用 TensorFlow/Keras 实现 WGAN 的基本示例。

import tensorflow as tf
from tensorflow.keras import layers

# 生成器
def build_generator(z_dim):
    model = tf.keras.Sequential()
    model.add(layers.Dense(128, input_dim=z_dim))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(28 * 28 * 1, activation='tanh'))  # MNIST长度
    model.add(layers.Reshape((28, 28, 1)))
    return model

# 判别器
def build_critic(img_shape):
    model = tf.keras.Sequential()
    model.add(layers.Flatten(input_shape=img_shape))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(1))  # 实数值
    return model

# 梯度惩罚
def gradient_penalty(critic, real_images, fake_images):
    batch_size = real_images.shape[0]
    epsilon = tf.random.uniform((batch_size, 1, 1, 1), 0.0, 1.0)
    interpolated_images = epsilon * real_images + (1 - epsilon) * fake_images

    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        D_interpolated = critic(interpolated_images)

    gradients = tape.gradient(D_interpolated, interpolated_images)[0]
    gradient_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
    return tf.reduce_mean((gradient_norm - 1.0) ** 2)

# 示例训练过程
def train_wgan(generator, critic, epochs, batch_size):
    for epoch in range(epochs):
        for _ in range(5):  # 更新判别器更多次
            # 获取批次数据 (这里使用随机生成假数据)
            noise = tf.random.normal((batch_size, z_dim))
            real_images = ...  # 这里应该加载真实数据
            
            fake_images = generator(noise)
            with tf.GradientTape() as tape:
                D_real = critic(real_images)
                D_fake = critic(fake_images)
                gp = gradient_penalty(critic, real_images, fake_images)
                D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + 10 * gp  # 10是lambda
            grads = tape.gradient(D_loss, critic.trainable_variables)
            critic.optimizer.apply_gradients(zip(grads, critic.trainable_variables))

        # 更新生成器
        noise = tf.random.normal((batch_size, z_dim))
        with tf.GradientTape() as tape:
            fake_images = generator(noise)
            D_fake = critic(fake_images)
            G_loss = -tf.reduce_mean(D_fake)
        grads = tape.gradient(G_loss, generator.trainable_variables)
        generator.optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        print(f'Epoch:
11 WGAN-GP (Wasserstein GAN with Gradient Penalty) 详细教程

11 WGAN-GP (Wasserstein GAN with Gradient Penalty) 详细教程

简介

WGAN-GP 是一种改进的生成对抗网络(GAN)机制,旨在解决传统 GAN 中训练不稳定和模式崩塌的问题。它通过引入一个新的损失函数,该损失函数基于 Wasserstein 距离,并使用梯度惩罚来确保生成器和判别器的 Lipschitz 连续性。

1. Wasserstein 距离

传统 GAN 使用的是 JS 散度或 Kullback-Leibler 散度,WGAN 采用Wasserstein 距离(Earth Mover’s Distance),它为模型提供了更好的训练信号。

Wasserstein 距离的优点

  • 提供直接的优化目标
  • 训练过程稳定且更易于收敛
  • 有效缓解了模式崩塌

2. WGAN-GP 的基本原理

2.1 损失函数

在 WGAN-GP 中,生成器 G 的目标是最小化以下损失函数:

1
L_G = -D(G(z))

判别器 D 的目标是最大化以下损失函数:

1
L_D = D(real_data) - D(G(z)) + λ * R(D)

其中 R(D) 是关于 D 的正则化项,λ 是超参数。

2.2 梯度惩罚

为了确保判别器 D 是 Lipschitz 连续的,WGAN-GP 引入了梯度惩罚。具体而言,梯度惩罚的形式为:

1
R(D) = E[(||∇D(θ)||_2 - 1)²]

这里,θ 是在 D 的采样的真实数据与生成数据混合后的线性插值数据。

3. 算法步骤

3.1 网络架构

通常,WGAN-GP 使用两个神经网络:

  • 生成网络 G:用于生成假数据。
  • 判别网络 D:用于区分真实数据和生成的数据。

3.2 训练过程

  1. 生成假数据

    • 从随机噪声 z 生成假数据:fake_data = G(z)
  2. 计算判别器损失

    • 选择一小批量真实数据 real_data
    • 计算生成的数据与真实数据的判别器输出 D(real_data)D(fake_data)
  3. 梯度惩罚

    • 对真实数据和生成数据进行插值,计算输出来获得梯度
    • 计算梯度惩罚损失 R(D)
  4. 更新判别器

    • 根据 L_D 更新判别器参数 D
  5. 更新生成器

    • 计算生成器损失 L_G 并更新生成器参数 G

4. 示例代码

以下是一个使用 PyTorch 实现 WGAN-GP 的简化示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import DataLoader

# 定义生成器 G 和判别器 D
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 28 * 28),
nn.Tanh()
)

def forward(self, z):
return self.model(z).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1)
)

def forward(self, img):
return self.model(img.view(-1, 28 * 28))

# 梯度惩罚计算
def gradient_penalty(D, real_data, fake_data):
alpha = torch.rand(real_data.size(0), 1)
interpolated = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

D_interpolated = D(interpolated)
grads = torch.autograd.grad(outputs=D_interpolated, inputs=interpolated,
grad_outputs=torch.ones(D_interpolated.size()),
create_graph=True, retain_graph=True)[0]

grad_penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()
return grad_penalty

# 初始化模型和优化器
G = Generator()
D = Discriminator()
optimizer_G = optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))

# 训练模型
for epoch in range(num_epochs):
for i, (real_data, _) in enumerate(dataloader):
batch_size = real_data.size(0)

# 更新判别器
optimizer_D.zero_grad()
z = torch.randn(batch_size, 100)
fake_data = G(z)

real_validity = D(real_data)
fake_validity = D(fake_data.detach())

gp = gradient_penalty(D, real_data, fake_data)
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + 10 * gp # λ = 10

d_loss.backward()
optimizer_D.step()

# 更新生成器
if i % n_critic == 0:
optimizer_G.zero_grad()
z = torch.randn(batch_size, 100)
fake_data = G(z)
g_loss = -torch.mean(D(fake_data))
g_loss.backward()
optimizer_G.step()

说明:

  • 此代码包含生成器和判别器的基本结构,并计算梯度惩罚。
  • n_critic 变量可以根据需要设定(例如,每更新 n_critic 次 D,才更新一次 G)。

5. 总结

WGAN-GP 是一种强大的生成对抗网络,通过 Wasserstein 距离和梯度惩罚显著增强了模型的稳定性和生成质量。对于希望深入学习生成对抗网络的研究人员和开发者来说,WGAN-GP 是一个非常值得学习和实施的方法。

12 从零学生成式对抗网络 (CycleGAN) 教程

12 从零学生成式对抗网络 (CycleGAN) 教程

什么是 CycleGAN?

CycleGAN 是一种生成对抗网络,允许我们在没有成对样本的数据集上进行图像到图像的转换。与传统GAN不同,CycleGAN通过引入循环一致性损失来实现图像转换,而不需要有成对的训练数据。

CycleGAN 的工作原理

CycleGAN 的基本思想是通过两个主要的生成器和判别器网络来实现图像的转换和恢复。图像从一种领域(例如,X)转换为另一种领域(例如,Y),并且能够通过反向操作将其转换回原始领域。

  • 生成器 G:将领域 X 的图像转换为领域 Y 的图像。
  • 生成器 F:将领域 Y 的图像转换为领域 X 的图像。
  • 判别器 D:判断生成的图像是否属于领域 Y
  • 判别器 E:判断生成的图像是否属于领域 X

循环一致性损失

CycleGAN 引入了一个关键概念,即 循环一致性损失,以确保转换是可逆的。这意味着当我们把图像从 X 转换为 Y 再转换回 X 的时候,我们应该能得到原始图像。

循环一致性损失可以表示为:

1
L_{cyc} = ||F(G(X)) - X||_1 + ||G(F(Y)) - Y||_1

其中 || . ||_1 是L1范数,表示两个图像之间的差异。

CycleGAN 的结构

1. 生成器的网络结构

生成器 G 和 F 通常采用 U-Net残差网络。U-Net 通过跳过连接来保留图像的细节。

1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
def __init__(self, in_channels):
super(ResNetBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.relu = nn.ReLU()

def forward(self, x):
return x + self.conv2(self.relu(self.conv1(x)))

2. 判别器的网络结构

判别器通常是一个卷积神经网络。其任务是判断输入的图像是真实的还是来自生成器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Discriminator(nn.Module):
def __init__(self, in_channels):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),
nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
)

def forward(self, x):
return self.model(x)

损失函数

1. 对抗损失

CycleGAN 使用 对抗损失 来训练判别器。对抗损失确保生成器生成的图像尽可能接近真实图像。

1
L_{GAN}(G, D) = \mathbb{E}_{y \sim Y}[D(y)] - \mathbb{E}_{x \sim X}[D(G(x))]

2. 总损失

CycleGAN 的总损失是对抗损失和循环一致性损失的组合:

1
L_{total} = L_{GAN}(G, D) + L_{GAN}(F, E) + L_{cyc}

训练过程

1. 数据准备

首先需要准备数据集,一般来说需要两个不同风格的图像数据集。例如,晚上的风景图像和白天的风景图像。

2. 定义训练循环

1
2
3
4
5
def train(cycle_gan, data_loader, num_epochs):
for epoch in range(num_epochs):
for i, (real_x, real_y) in enumerate(data_loader):
loss = cycle_gan.step(real_x, real_y)
print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(data_loader)}], Loss: {loss:.4f}')

3. 生成图像

训练完成后,可以使用生成器生成转换后的图像:

1
2
with torch.no_grad():
generated_image = cycle_gan.generator_G(real_image)

4. 可视化结果

可以使用 matplotlib 库来可视化生成的图像。

1
2
3
4
5
import matplotlib.pyplot as plt

plt.imshow(generated_image.squeeze().permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()

总结

通过以上步骤,我们可以使用 CycleGAN 将一种风格的图像转换为另一种风格。CycleGAN 不需要成对的数据进行训练,这使得它在许多实际应用中非常有用。希望这个详细的教程能帮助你理解和实现 CycleGAN。