import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
# Load MNIST and scale pixels to [-1, 1] to match the generator's tanh output.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 127.5 - 1.0
x_train = np.expand_dims(x_train, axis=-1)
def build_generator(z_dim):
    # Map a z_dim-dimensional noise vector to a 28x28x1 image in [-1, 1].
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_dim=z_dim))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(1024))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(28 * 28 * 1, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
    return model
def build_discriminator():
    # Classify a 28x28x1 image as real (1) or fake (0).
    model = tf.keras.Sequential()
    model.add(layers.Flatten(input_shape=(28, 28, 1)))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model
z_dim = 100
generator = build_generator(z_dim)
discriminator = build_discriminator()
# Compile the discriminator first, then freeze it so that the combined GAN
# model built below updates only the generator's weights. The standalone
# discriminator keeps its trainable weights because it was compiled before
# the flag was flipped.
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
discriminator.trainable = False
# Stacked model: noise -> generator -> (frozen) discriminator.
gan_input = layers.Input(shape=(z_dim,))
fake_image = generator(gan_input)
discriminator_output = discriminator(fake_image)
gan = tf.keras.Model(gan_input, discriminator_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
def train_gan(epochs=10000, batch_size=128):
    # Note: each "epoch" here is a single training step on one batch,
    # not a full pass over the dataset.
    for e in range(epochs):
        # Sample a batch of real images and generate a batch of fakes.
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        fake_images = generator.predict(noise, verbose=0)
        # Train the discriminator on real (label 1) and fake (label 0) batches.
        d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        # Train the generator via the combined model: push the frozen
        # discriminator toward outputting 1 for generated images.
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
        if e % 1000 == 0:
            print(f'Epoch: {e}, D_loss: {d_loss[0]}, G_loss: {g_loss}')
train_gan()
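
# matplotlib is imported above but never used; a plotting helper like the one
# below is presumably what it was intended for. This is a minimal sketch (the
# function name and the 4x4 grid are our own choices, not from the original):
# it samples noise, runs the generator, and rescales the tanh output from
# [-1, 1] back to [0, 1] for display.
def plot_generated_images(n_rows=4, n_cols=4):
    noise = np.random.normal(0, 1, (n_rows * n_cols, z_dim))
    generated = generator.predict(noise, verbose=0)
    generated = 0.5 * generated + 0.5  # [-1, 1] -> [0, 1]
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))
    for img, ax in zip(generated, axes.flat):
        ax.imshow(img.squeeze(), cmap='gray')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

plot_generated_images()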