Day 23: GAN Improvements - Enhancing Performance for Fashion MNIST Generation
Welcome to Day 23 of our deep learning challenge! Today, we will discuss the improvements made to the Generative Adversarial Network (GAN) model to generate clearer images for the Fashion MNIST dataset. We’ll explore the changes made to the generator, discriminator, and the overall training process to help enhance the output quality. Let’s dive into the improvements and understand why they were made.
Step 1: Importing Libraries and Loading Data
We start by importing necessary libraries and loading the Fashion MNIST dataset.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization, Conv2DTranspose, Conv2D
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import fashion_mnist
import pandas as pd
import os
from tensorflow.keras.optimizers import Adam
# Load the data
(X_train, _), (_, _) = fashion_mnist.load_data()
# Normalize to [-1, 1] to match the tanh activation used in the generator
X_train = (X_train - 127.5) / 127.5
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
- Normalization: The data is normalized between -1 and 1 to match the output range of the tanh activation used in the generator. This helps the model learn more effectively.
- Reshaping: The data is reshaped to have a single channel (28, 28, 1) to fit the expected input shape of the generator and discriminator.
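A quick, optional sanity check confirms that the preprocessing did what we expect (these are the exact shapes and ranges the steps above produce):

print(X_train.shape)                 # (60000, 28, 28, 1)
print(X_train.min(), X_train.max())  # -1.0 1.0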
Step 2: Generator Improvements
Original Generator
The original generator used fully connected layers to upscale the latent vector into a 28x28 image. While this approach can work, it struggles with spatial resolution and generating detailed images.
Improved Generator
The improved generator uses transposed convolution layers (Conv2DTranspose) to better handle upsampling and generate clearer images.
def build_improved_generator():
    model = Sequential()
    model.add(Dense(7 * 7 * 256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 256)))
    # Upsample to 14x14
    model.add(Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # Upsample to 28x28
    model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # Final layer to generate single-channel images in [-1, 1]
    model.add(Conv2D(1, kernel_size=7, activation='tanh', padding='same'))
    return model
Changes Explained:
- Dense Layer with Reshape: The generator starts with a Dense layer that outputs 7 * 7 * 256 units, which are then reshaped to (7, 7, 256). This forms a low-resolution base to build upon.
- Transposed Convolutions (Conv2DTranspose): Instead of fully connected layers, transposed convolutions gradually upscale the feature maps to 28x28. This helps retain spatial hierarchies and generates clearer images.
- LeakyReLU Activation: LeakyReLU avoids dead neurons and improves gradient flow. It uses a small slope for negative values (alpha=0.2), allowing some negative gradient to pass through.
- Batch Normalization: Helps stabilize training and enables the model to converge faster by normalizing activations.
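To verify that the upsampling path really ends at 28x28, you can build the model once and inspect its output shape (an optional check, not part of the walkthrough itself):

print(build_improved_generator().output_shape)  # expect (None, 28, 28, 1)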
Step 3: Discriminator Improvements
The original discriminator consisted of fully connected layers, which are not ideal for extracting spatial features from images. A convolutional discriminator would be a natural next step, but in this version we retain the fully connected structure and focus the improvements on the training process instead.
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    # Dense layers to process the flattened image
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    # Final output layer to classify real (1) or fake (0)
    model.add(Dense(1, activation='sigmoid'))
    return model
Changes Explained:
- Flatten Layer: The Flatten layer is used to convert the image into a 1D vector.
- LeakyReLU Activation: This activation function is used after each Dense layer, making the network more resilient against the vanishing gradient problem.
Compilation and Learning Rate Changes
- The discriminator uses a learning rate of 0.0002 for the Adam optimizer, and the GAN uses a smaller rate of 0.0001. These lower rates help keep training stable.
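The compile step for the discriminator isn't shown above; a minimal sketch consistent with these rates might look like this. Note that metrics=['accuracy'] is an assumption here, made because the training loop below reads d_loss[1] as an accuracy, which requires train_on_batch to return one.

generator = build_improved_generator()
discriminator = build_discriminator()
# Compile the discriminator while it is still trainable so that
# discriminator.train_on_batch updates its weights. The accuracy metric
# (assumed, see above) makes train_on_batch return [loss, accuracy].
discriminator.compile(optimizer=Adam(learning_rate=0.0002),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])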
Step 4: Building the GAN
To train both the generator and discriminator as a combined model, we freeze the discriminator’s weights and define the GAN model.
def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model

discriminator.trainable = False
gan = build_gan(generator, discriminator)
gan.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy')
Changes Explained:
- discriminator.trainable = False: This line freezes the discriminator while training the combined GAN. We do not want the discriminator's weights to update while the generator is being trained, since the generator's objective is to fool the discriminator. In Keras the trainable flag is captured when a model is compiled, so the standalone discriminator (compiled while still trainable) continues to learn from its own train_on_batch calls.
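If you want to confirm the freeze took effect, one optional check (assuming the models above are built and compiled) is to run a single combined-model update and verify that the discriminator's weights did not move:

w_before = [w.copy() for w in discriminator.get_weights()]
gan.train_on_batch(np.random.normal(0, 1, (16, 100)), np.ones((16, 1)))
w_after = discriminator.get_weights()
# Expect True: the discriminator was frozen before gan.compile()
print(all(np.array_equal(a, b) for a, b in zip(w_before, w_after)))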
Step 5: Training the GAN
We made multiple improvements to the training process to help stabilize and enhance the quality of generated images.
# Training Parameters
epochs = 30000
batch_size = 16
save_intervals = 1000

# Smoothed labels for real and fake images
real = np.ones((batch_size, 1)) * 0.9
fake = np.zeros((batch_size, 1)) + 0.1

for epoch in range(epochs):
    # Train the discriminator
    random_idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_images = X_train[random_idx]
    noise = np.random.normal(0, 1, (batch_size, 100))
    generated_images = generator.predict(noise)

    # Random flips to add noise to the discriminator. The flip applies to
    # this batch only; swapping real and fake permanently would also
    # corrupt the generator's targets below.
    if np.random.rand() < 0.1:
        real_labels, fake_labels = fake, real
    else:
        real_labels, fake_labels = real, fake

    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator twice per step to give it more opportunity
    for _ in range(2):
        noise = np.random.normal(0, 1, (batch_size, 100))
        gan_loss = gan.train_on_batch(noise, real)

    if epoch % save_intervals == 0:
        print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}%] [G loss: {gan_loss}]")
        # Plot generated images to visualize training progress
        generated_images = generator.predict(noise)
        generated_images = 0.5 * generated_images + 0.5  # Rescale from [-1, 1] to [0, 1]
        plt.figure(figsize=(5, 5))
        for i in range(4):
            plt.subplot(2, 2, i + 1)
            plt.imshow(generated_images[i, :, :, 0], cmap='gray')
            plt.axis('off')
        plt.show()
Training Improvements:
- Label Smoothing and Noise: Labels for real and fake images are smoothed (0.9 for real, 0.1 for fake) and occasionally flipped to add noise. This prevents the discriminator from becoming overly confident, which can destabilize GAN training.
- Train Generator Twice: The generator is trained twice per step to give it more opportunities to learn and keep up with the discriminator. This helps when the discriminator tends to become too powerful.
- Random Label Flipping: Randomly flipping the real and fake labels during discriminator training further ensures that the discriminator doesn't become too dominant, which can lead to mode collapse.
- Save Interval Visualization: We plot generated images at intervals (every 1000 epochs) to track the GAN's progress and observe improvements over time; a sketch for also saving these grids to disk follows this list.
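The loop above only displays the grids with plt.show(); to also write them to disk (the otherwise unused os import at the top suggests this was intended), a minimal sketch could look like the following. The directory name and file pattern are illustrative assumptions, not from the original walkthrough.

def save_image_grid(generator, epoch, out_dir='gan_images'):
    # Hypothetical helper: writes a 2x2 grid of generated samples to disk.
    os.makedirs(out_dir, exist_ok=True)
    noise = np.random.normal(0, 1, (4, 100))
    images = 0.5 * generator.predict(noise) + 0.5  # rescale to [0, 1]
    fig = plt.figure(figsize=(5, 5))
    for i in range(4):
        plt.subplot(2, 2, i + 1)
        plt.imshow(images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    fig.savefig(os.path.join(out_dir, f'epoch_{epoch}.png'))
    plt.close(fig)

Calling save_image_grid(generator, epoch) inside the if epoch % save_intervals == 0: block replaces plt.show() when running unattended.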
Summary of Improvements
- Generator Architecture: Improved by using transposed convolutions to better handle spatial upsampling, leading to sharper images.
- Discriminator Training: Smoothing and adding noise to the labels, as well as increasing the generator's training frequency, resulted in a more stable training process.
- Learning Rates: Different learning rates for the discriminator and GAN helped maintain balance and stability during training.
- Random Label Flipping: This added noise to the discriminator's training, keeping it from becoming overconfident and dominating the generator.
These changes significantly enhanced the performance of the GAN, helping it generate more realistic Fashion MNIST images and stabilizing the training process.