Day 23: GAN Improvements - Enhancing Performance for Fashion MNIST Generation

Welcome to Day 23 of our deep learning challenge! Today we cover the improvements made to our Generative Adversarial Network (GAN) so that it generates clearer images for the Fashion MNIST dataset. We'll walk through the changes to the generator, the discriminator, and the overall training process, and explain why each one improves output quality. Let's dive in.

Step 1: Importing Libraries and Loading Data

We start by importing necessary libraries and loading the Fashion MNIST dataset.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization, Conv2DTranspose, Conv2D
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import fashion_mnist
import os
from tensorflow.keras.optimizers import Adam

# Load the data
(X_train, _), (_, _) = fashion_mnist.load_data()

# Normalize pixel values to [-1, 1] to match the generator's tanh output
X_train = (X_train - 127.5) / 127.5
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
  • Normalization: The pixel values are scaled to [-1, 1] to match the output range of the tanh activation used in the generator's final layer, so real and generated images live on the same scale.
  • Reshaping: The data is reshaped to (28, 28, 1), adding an explicit channel dimension to fit the input shape expected by the generator and discriminator.
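
As a quick sanity check (an addition for illustration), you can confirm the preprocessing produced the expected range and shape:

# Expected output: -1.0 1.0 (60000, 28, 28, 1)
print(X_train.min(), X_train.max(), X_train.shape)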

Step 2: Generator Improvements

Original Generator

The original generator used fully connected layers to upscale the latent vector into a 28x28 image. While this approach can work, it struggles to capture spatial structure and produce detailed images.

Improved Generator

The improved generator uses transposed convolution layers (Conv2DTranspose) to better handle upsampling and generate clearer images.

def build_improved_generator():
    model = Sequential()

    model.add(Dense(7 * 7 * 256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 256)))

    # Upsampling to 14x14
    model.add(Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    # Upsampling to 28x28
    model.add(Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    # Final layer to generate images
    model.add(Conv2D(1, kernel_size=7, activation='tanh', padding='same'))

    return model

Changes Explained:

  • Dense Layer with Reshape: The generator starts with a Dense layer that outputs a shape of (7, 7, 256) followed by reshaping, which helps form a low-resolution base to build upon.
  • Transposed Convolutions (Conv2DTranspose): Instead of fully connected layers, transposed convolutions are used to gradually upscale the image to 28x28. This helps retain spatial hierarchies and generates clearer images.
  • LeakyReLU Activation: LeakyReLU avoids dead neurons by applying f(x) = x for x > 0 and f(x) = αx (here α = 0.2) for x ≤ 0, so negative pre-activations still pass a small gradient instead of being zeroed out as with plain ReLU.
  • Batch Normalization: Helps stabilize training and enables the model to converge faster by normalizing activations.
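
To confirm the upsampling path (7x7 → 14x14 → 28x28) ends at the right resolution, a quick shape check (an illustrative addition) helps:

# Build the generator and pass a single random latent vector through it
generator = build_improved_generator()
noise = np.random.normal(0, 1, (1, 100)).astype('float32')
print(generator(noise).shape)  # Expected: (1, 28, 28, 1)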

Step 3: Discriminator Improvements

The original discriminator consisted of fully connected layers, which are not ideal for extracting spatial features from images; a convolutional discriminator would likely perform better. In this version, however, we retain the fully connected structure and focus instead on changes to the training process.

def build_discriminator():
    model = Sequential()

    model.add(Flatten(input_shape=(28, 28, 1)))

    # Add layers to process flatten image
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))

    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))

    # Final output layer to classify real(1) or fake(0)
    model.add(Dense(1, activation='sigmoid'))

    return model

Changes Explained:

  • Flatten Layer: The Flatten layer is used to convert the image into a 1D vector.
  • LeakyReLU Activation: This activation function is used after each Dense layer, making the network more resilient against the vanishing gradient problem.
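
As a quick check (an illustrative addition, not from the original code), you can confirm the discriminator maps a batch of images to one probability each:

check_disc = build_discriminator()
dummy_batch = np.zeros((2, 28, 28, 1), dtype='float32')
# Sigmoid output of shape (2, 1); with an all-zero input and Keras's
# default zero bias initialization, every value is exactly 0.5
print(check_disc(dummy_batch).numpy())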

Compilation and Learning Rate Changes

  • The discriminator uses a learning rate of 0.0002 for the Adam optimizer, and the GAN uses a smaller rate of 0.0001. These lower rates help make the training more stable.
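
The original snippet stops short of compiling the discriminator; a minimal sketch consistent with the rates quoted above might look like this (metrics=['accuracy'] is an assumption, added so the training loop below can report discriminator accuracy):

discriminator = build_discriminator()
# The accuracy metric is assumed here so the training loop can print it
discriminator.compile(optimizer=Adam(learning_rate=0.0002),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])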

Step 4: Building the GAN

To train both the generator and discriminator as a combined model, we freeze the discriminator’s weights and define the GAN model.

def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)      # noise -> generated image
    model.add(discriminator)  # generated image -> real/fake probability
    return model

# The discriminator must already be compiled (as above) before it is frozen;
# freezing only affects the copy embedded in the combined model
discriminator.trainable = False
gan = build_gan(generator, discriminator)
gan.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy')

Changes Explained:

  • discriminator.trainable = False: This freezes the discriminator's weights inside the combined model, so gan.train_on_batch updates only the generator; the GAN's objective is to trick the discriminator, not to improve it. Because the discriminator was compiled before this flag was changed, it still updates normally when trained directly with its own train_on_batch calls.
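
A quick check (an illustrative addition) confirms the freeze works as intended: the combined model should expose only the generator's weights as trainable.

# The GAN's trainable weights should match the generator's exactly,
# while the frozen discriminator reports no trainable weights of its own
print(len(gan.trainable_weights) == len(generator.trainable_weights))  # True
print(len(discriminator.trainable_weights))                            # 0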

Step 5: Training the GAN

We made multiple improvements to the training process to stabilize training and enhance the quality of the generated images.

# Training parameters (each "epoch" here is a single batch-sized training step)
epochs = 30000
batch_size = 16
save_interval = 1000

# Smoothed labels for real and fake images
real = np.ones((batch_size, 1)) * 0.9
fake = np.zeros((batch_size, 1)) + 0.1

for epoch in range(epochs):
    # Train the discriminator
    random_idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_images = X_train[random_idx]

    noise = np.random.normal(0, 1, (batch_size, 100))
    generated_images = generator.predict(noise, verbose=0)

    # Randomly flip the labels for this batch only, adding noise
    # to the discriminator's training signal
    if np.random.rand() < 0.1:
        real_labels, fake_labels = fake, real
    else:
        real_labels, fake_labels = real, fake

    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator twice per step to give it more opportunity
    # to keep up with the discriminator
    for _ in range(2):
        noise = np.random.normal(0, 1, (batch_size, 100))
        gan_loss = gan.train_on_batch(noise, real)

    if epoch % save_interval == 0:
        print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}%] [G loss: {gan_loss}]")

        # Save and plot generated images to visualize training progress
        generated_images = generator.predict(noise, verbose=0)
        generated_images = 0.5 * generated_images + 0.5  # Rescale from [-1, 1] to [0, 1]

        plt.figure(figsize=(5, 5))
        for i in range(4):
            plt.subplot(2, 2, i + 1)
            plt.imshow(generated_images[i, :, :, 0], cmap='gray')
            plt.axis('off')
        os.makedirs('images', exist_ok=True)  # output directory (name is arbitrary)
        plt.savefig(f'images/epoch_{epoch}.png')
        plt.show()

Training Improvements:

  1. Label Smoothing: Labels for real and fake images are smoothed (0.9 for real, 0.1 for fake) instead of hard 1s and 0s. This prevents the discriminator from becoming overly confident, which can destabilize GAN training.

  2. Train Generator Twice: The generator is trained twice per iteration to give it more opportunities to learn and keep up with the discriminator. This helps when the discriminator tends to become too powerful.

  3. Random Label Flipping: With 10% probability, the real and fake labels are swapped for a single discriminator batch. This added noise further ensures that the discriminator doesn't become too dominant, which can lead to mode collapse.

  4. Save Interval Visualization: Generated images are saved and plotted every 1000 iterations. This visualization helps track the progress of the GAN and lets us observe improvements over time.
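
Once training finishes, sampling new items needs only the generator. A minimal sketch, reusing the same rescaling as the training loop:

# Generate and display a 4x4 grid of new Fashion-MNIST-style images
noise = np.random.normal(0, 1, (16, 100))
samples = generator.predict(noise, verbose=0)
samples = 0.5 * samples + 0.5  # Rescale from [-1, 1] to [0, 1]

plt.figure(figsize=(5, 5))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(samples[i, :, :, 0], cmap='gray')
    plt.axis('off')
plt.show()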

Summary of Improvements

  • Generator Architecture: Transposed convolutions handle spatial upsampling far better than fully connected layers, leading to sharper images.
  • Discriminator Training: Smoothed, noisy labels, together with training the generator twice per iteration, made the training process more stable.
  • Learning Rates: Separate, lower learning rates for the discriminator (0.0002) and the combined GAN (0.0001) helped maintain balance and stability during training.
  • Random Label Flipping: Occasional per-batch label flips added robustness to the discriminator's training and kept it from overfitting to easy real/fake distinctions.

These changes significantly enhanced the performance of the GAN, helping it generate more realistic Fashion MNIST images and stabilizing the training process.
