10 WGAN (Wasserstein GAN) 详细教程
1. 引言
生成对抗网络(GAN)是一种强大的生成模型,但在训练过程中常常会遭遇不稳定性和模式崩溃等问题。为了解决这些问题,Wetzerstein GAN(WGAN)有效地改进了GAN的训练方式,通过引入Wasserstein距离
来度量真实数据和生成数据之间的差异。WGAN具有更稳定的训练过程和更好的生成效果。
2. WGAN 基础概念
2.1 Wasserstein 距离
- Wasserstein 距离,又称地球移动者距离(Earth Mover’s Distance, EMD),用于量化两个概率分布之间的差异。
- 与传统的 Jensen-Shannon Divergence 相比,Wasserstein 距离对生成分布的支持情况更加敏感,这使得在训练时模型能够更好地捕捉数据分布的变化。
2.2 作图
在训练过程中,WGAN追求的是:
$$
W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma} [| x - y |]
$$
其中,$\Pi(P_r, P_g)$ 是所有将真实分布 $P_r$ 和生成分布 $P_g$ 结合的联合分布的集合。
3. WGAN 的模型结构
WGAN 的结构与标准 GAN 类似,包括一个生成器(Generator, G)和一个判别器(Discriminator, D)。但是,WGAN 的判别器不再是一个概率输出模型,而是一个 Lipschitz
连续函数。为了满足这一条件,通常使用一个权重裁剪或梯度惩罚的方法。
3.1 生成器(G)
- 输入:随机噪声 $\mathbf{z}$,通常从一个均匀分布或正态分布中采样。
- 输出:生成样本 $\mathbf{G(z)}$。
3.2 判别器(D)
- 输入:数据样本(真实样本或生成样本)。
- 输出:一个实数值,用于反映样本的真实性。
4. WGAN 训练过程
4.1 损失函数
WGAN 的损失函数定义如下:
- 对于生成器:
$$
L_G = -\mathbb{E}[D(G(z))]
$$
- 对于判别器:
$$
L_D = \mathbb{E}[D(x)] - \mathbb{E}[D(G(z))] + \lambda \cdot \mathbb{E}[(|\nabla D(x) |_2 - 1)^2]
$$
其中,$\lambda$ 是梯度惩罚的超参数。
4.2 梯度惩罚
为了满足 Lipschitz
条件,而不是简单地裁剪权重,可以通过以下方式来加以限制:
- 从真实样本和生成样本中间随机采样,生成新的样本 $\hat{x} = \epsilon \cdot x + (1 - \epsilon) \cdot G(z)$,其中 $\epsilon$ 从均匀分布 $U(0, 1)$ 中采样。
- 计算判别器对 $\hat{x}$ 输入的梯度。
- 添加梯度惩罚项。
5. WGAN实例代码
以下是一个用 TensorFlow/Keras 实现 WGAN 的基本示例。
import tensorflow as tf
from tensorflow.keras import layers
# 生成器
def build_generator(z_dim):
model = tf.keras.Sequential()
model.add(layers.Dense(128, input_dim=z_dim))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(256))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(28 * 28 * 1, activation='tanh')) # MNIST长度
model.add(layers.Reshape((28, 28, 1)))
return model
# 判别器
def build_critic(img_shape):
model = tf.keras.Sequential()
model.add(layers.Flatten(input_shape=img_shape))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(256))
model.add(layers.LeakyReLU(alpha=0.2))
model.add(layers.Dense(1)) # 实数值
return model
# 梯度惩罚
def gradient_penalty(critic, real_images, fake_images):
batch_size = real_images.shape[0]
epsilon = tf.random.uniform((batch_size, 1, 1, 1), 0.0, 1.0)
interpolated_images = epsilon * real_images + (1 - epsilon) * fake_images
with tf.GradientTape() as tape:
tape.watch(interpolated_images)
D_interpolated = critic(interpolated_images)
gradients = tape.gradient(D_interpolated, interpolated_images)[0]
gradient_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
return tf.reduce_mean((gradient_norm - 1.0) ** 2)
# 示例训练过程
def train_wgan(generator, critic, epochs, batch_size):
for epoch in range(epochs):
for _ in range(5): # 更新判别器更多次
# 获取批次数据 (这里使用随机生成假数据)
noise = tf.random.normal((batch_size, z_dim))
real_images = ... # 这里应该加载真实数据
fake_images = generator(noise)
with tf.GradientTape() as tape:
D_real = critic(real_images)
D_fake = critic(fake_images)
gp = gradient_penalty(critic, real_images, fake_images)
D_loss = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real) + 10 * gp # 10是lambda
grads = tape.gradient(D_loss, critic.trainable_variables)
critic.optimizer.apply_gradients(zip(grads, critic.trainable_variables))
# 更新生成器
noise = tf.random.normal((batch_size, z_dim))
with tf.GradientTape() as tape:
fake_images = generator(noise)
D_fake = critic(fake_images)
G_loss = -tf.reduce_mean(D_fake)
grads = tape.gradient(G_loss, generator.trainable_variables)
generator.optimizer.apply_gradients(zip(grads, generator.trainable_variables))
print(f'Epoch: