"""Train a small fully-connected GAN on MNIST and plot samples from the generator."""

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers

# Only the MNIST training images are used; labels and the test split are discarded.
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Scale uint8 pixels to [0, 1] -- the same range as the generator's sigmoid
# output -- and use float32 (Keras' default dtype), which needs half the memory
# of the float64 array that plain division would produce.
X_train = X_train.astype('float32') / 255.0
# Add a trailing channel axis so images match the (28, 28, 1) model inputs.
X_train = np.expand_dims(X_train, axis=-1)
def build_generator(latent_dim=100, hidden_units=128, image_shape=(28, 28, 1)):
    """Build the fully-connected generator: noise vector -> image.

    Args:
        latent_dim: Length of the input noise vector.
        hidden_units: Width of the single ReLU hidden layer.
        image_shape: Shape of each generated image (height, width, channels).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model mapping ``(batch, latent_dim)``
        noise to ``(batch, *image_shape)`` images with values in [0, 1].
    """
    image_shape = tuple(image_shape)
    model = tf.keras.Sequential()
    # An explicit Input layer replaces the `input_shape=` layer kwarg, which
    # Keras 3 warns about; the resulting model is the same.
    model.add(tf.keras.Input(shape=(latent_dim,)))
    model.add(layers.Dense(hidden_units, activation='relu'))
    # Sigmoid keeps pixels in [0, 1], matching the /255-scaled training images.
    model.add(layers.Dense(int(np.prod(image_shape)), activation='sigmoid'))
    model.add(layers.Reshape(image_shape))
    return model
def build_discriminator(image_shape=(28, 28, 1), hidden_units=128):
    """Build the fully-connected discriminator: image -> single sigmoid score.

    The training loop labels real images 1 and generated images 0, so the
    output is the model's estimate that the input is real.

    Args:
        image_shape: Shape of each input image (height, width, channels).
        hidden_units: Width of the single ReLU hidden layer.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model mapping
        ``(batch, *image_shape)`` images to ``(batch, 1)`` scores in [0, 1].
    """
    model = tf.keras.Sequential()
    # Explicit Input layer instead of the `input_shape=` kwarg on Flatten.
    model.add(tf.keras.Input(shape=tuple(image_shape)))
    model.add(layers.Flatten())
    model.add(layers.Dense(hidden_units, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model
# Instantiate both networks.
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator while it is still trainable, so its own
# train_on_batch() calls update its weights.
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Freeze the discriminator BEFORE compiling the stacked model. Keras takes the
# trainable state into account at compile time, so `gan` updates only the
# generator's weights. Do not reorder these statements.
discriminator.trainable = False
# Noise input; its width must match the generator's input width (100 in build_generator).
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

# Stacked model: noise -> generator -> (frozen) discriminator. train_gan fits it
# against all-ones targets, i.e. it trains the generator to make the
# discriminator call its images real.
gan = tf.keras.models.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
def train_gan(epochs=1, batch_size=128, latent_dim=100):
    """Train the module-level GAN with alternating discriminator/generator steps.

    Uses the module-level ``X_train``, ``generator``, ``discriminator`` and ``gan``.
    Each step first trains the discriminator on ``batch_size`` real images
    (sampled with replacement) labelled 1 plus ``batch_size`` generated images
    labelled 0. It then trains the generator through ``gan`` on fresh noise with
    targets of 1.

    Args:
        epochs: Number of epochs; each runs ``len(X_train) // batch_size`` steps
            (at least one).
        batch_size: Number of real images, and of generated images, per step.
        latent_dim: Length of each noise vector; must match the generator's
            input width (100 for the default ``build_generator``).

    Returns:
        A list with one ``(discriminator_loss, generator_loss)`` tuple per epoch,
        each the mean loss over that epoch's steps.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``X_train`` is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n_samples = X_train.shape[0]
    if n_samples == 0:
        raise ValueError("X_train is empty; nothing to train on")

    # Real images are drawn with replacement, so a full batch is always
    # available; run at least one step even if the data is smaller than a batch.
    steps_per_epoch = max(1, n_samples // batch_size)

    # Targets never change between steps, so build them once.
    d_labels = np.concatenate([np.ones(batch_size), np.zeros(batch_size)])
    g_labels = np.ones(batch_size)

    history = []
    for _ in range(epochs):
        d_losses, g_losses = [], []
        for _ in range(steps_per_epoch):
            # Discriminator step on half real, half generated images.
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            # predict_on_batch is a single forward pass, without predict()'s
            # per-call progress bar and input-pipeline setup.
            fake_images = generator.predict_on_batch(noise)
            real_images = X_train[np.random.randint(0, n_samples, size=batch_size)]
            batch_images = np.concatenate([real_images, fake_images])
            # NOTE(review): per the Keras docs, changing `trainable` after compile()
            # only takes effect on re-compile; the effective trainable sets were
            # fixed when `discriminator` and `gan` were compiled. The toggles are
            # kept so behaviour is unchanged across Keras versions.
            discriminator.trainable = True
            d_out = discriminator.train_on_batch(batch_images, d_labels)

            # Generator step through the stacked model with the discriminator frozen.
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            discriminator.trainable = False
            g_out = gan.train_on_batch(noise, g_labels)

            # train_on_batch returns either a scalar loss or [loss, *metrics].
            d_losses.append(float(np.ravel(d_out)[0]))
            g_losses.append(float(np.ravel(g_out)[0]))
        history.append((float(np.mean(d_losses)), float(np.mean(g_losses))))
    return history
# Run 100 epochs; each discriminator step sees 128 real + 128 generated images.
train_gan(batch_size=128, epochs=100)
def plot_generated_images(generator, n_examples=10, latent_dim=100):
    """Sample ``n_examples`` images from ``generator`` and show them in one row.

    Args:
        generator: Model mapping ``(n, latent_dim)`` noise to images. Assumed to
            output single-channel images with values in [0, 1], as the sigmoid
            generator from ``build_generator`` does.
        n_examples: Number of images to draw; must be >= 1.
        latent_dim: Length of each noise vector fed to ``generator``.

    Raises:
        ValueError: If ``n_examples`` < 1 (the figure would have zero width).
    """
    if n_examples < 1:
        raise ValueError(f"n_examples must be >= 1, got {n_examples}")

    noise = np.random.normal(0, 1, size=[n_examples, latent_dim])
    # predict_on_batch avoids predict()'s progress-bar output for a single batch.
    images = np.asarray(generator.predict_on_batch(noise))

    # One inch per image; identical to the original 10x1 figure at the default.
    plt.figure(figsize=(n_examples, 1))
    for i, image in enumerate(images):
        frame = np.squeeze(image)  # (H, W, 1) -> (H, W)
        if frame.ndim != 2:
            # Fall back to the original MNIST layout (e.g. flat 784-vectors).
            frame = frame.reshape(28, 28)
        plt.subplot(1, n_examples, i + 1)
        # Fix the colour range to the generator's [0, 1] output so imshow does
        # not stretch each sample to its own min/max.
        plt.imshow(frame, cmap='gray', vmin=0, vmax=1)
        plt.axis('off')
    plt.show()
# Show a row of samples from the generator after training.
plot_generated_images(generator=generator)