Day 24: Exploring Conditional GANs (CGANs) With Fashion MNIST

Today, I dived into Conditional GANs (CGANs), a variation of GANs that can generate a specific type of image on demand by conditioning on a label. Using the Fashion MNIST dataset, I implemented a CGAN to generate images of specific clothing items like shirts, shoes, and bags.

What is a Conditional GAN?

A Conditional GAN (CGAN) is an extension of GANs in which both the generator and the discriminator receive additional information, such as class labels or attributes. In this project:

  • The generator takes both noise (random input) and a label (e.g., “sneaker”).
  • The discriminator evaluates whether an image-label pair is real or fake.

This setup allows the CGAN to generate specific types of images based on the provided label.
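
To make the conditioning concrete, here is a minimal NumPy sketch (not part of the original project) that builds generator inputs by concatenating noise with a one-hot label. It assumes the same 100-dimensional noise and 10 classes used in the code below. The helper `conditioned_input` is hypothetical.

```python
import numpy as np

NOISE_DIM = 100   # same latent size as the generator in the post
NUM_CLASSES = 10  # Fashion MNIST has 10 classes

def conditioned_input(class_ids, rng):
    """Build a batch of generator inputs laid out as [noise | one-hot label] (hypothetical helper)."""
    class_ids = np.asarray(class_ids)
    noise = rng.normal(0.0, 1.0, size=(len(class_ids), NOISE_DIM))
    labels = np.eye(NUM_CLASSES)[class_ids]  # one one-hot row per requested class
    return np.concatenate([noise, labels], axis=1)

rng = np.random.default_rng(0)
batch = conditioned_input([7, 7, 6], rng)  # 7 = Sneaker, 6 = Shirt in Fashion MNIST
print(batch.shape)  # (3, 110)
```

The same noise vector paired with a different label should, after training, yield a different clothing item. That difference is the whole point of conditioning.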

Code:

# Problem: Conditional GAN (CGAN) for Generating Specific Images from Fashion MNIST

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization, Conv2DTranspose, Conv2D, \
    Input, Concatenate
from tensorflow.keras.models import Sequential, Model
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

# Load the Fashion MNIST dataset
(X_train, y_train), (_, _) = fashion_mnist.load_data()

# Normalize the images to the range [-1, 1] to fit the tanh activation function in the generator
X_train = (X_train - 127.5) / 127.5
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')

# One-hot encode the labels for conditioning
num_classes = 10
y_train = to_categorical(y_train, num_classes)


# Function to build the generator
def build_generator():
    # Inputs for the generator
    noise_input = Input(shape=(100,))
    label_input = Input(shape=(num_classes,))

    # Concatenate noise and label to create the input for the generator
    model_input = Concatenate()([noise_input, label_input])

    x = Dense(7 * 7 * 256, activation='relu')(model_input)
    x = Reshape((7, 7, 256))(x)
    x = BatchNormalization(momentum=0.8)(x)

    # Upsample to 14x14
    x = Conv2DTranspose(128, kernel_size=4, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)

    # Upsample to 28x28
    x = Conv2DTranspose(64, kernel_size=4, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)

    # Final layer to generate an image with 28x28 dimensions and 1 channel
    img_output = Conv2D(1, kernel_size=7, activation='tanh', padding='same')(x)

    return Model([noise_input, label_input], img_output)


# Build the generator model
generator = build_generator()
generator.summary()


# Function to build the discriminator
def build_discriminator():
    # Inputs for the image and the label
    img_input = Input(shape=(28, 28, 1))
    label_input = Input(shape=(num_classes,))

    # Embed the label and reshape to match the image shape
    label_embedding = Dense(28 * 28)(label_input)
    label_embedding = Reshape((28, 28, 1))(label_embedding)

    # Concatenate the image and label embedding
    combined_input = Concatenate()([img_input, label_embedding])

    # Flatten the combined input and pass through dense layers
    x = Flatten()(combined_input)
    x = Dense(512)(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Dense(256)(x)
    x = LeakyReLU(alpha=0.2)(x)

    # Final output layer to classify real (1) or fake (0)
    validity_output = Dense(1, activation='sigmoid')(x)

    # Create the model that takes the image and label as input
    return Model([img_input, label_input], validity_output)


# Build and compile the discriminator model
discriminator = build_discriminator()
discriminator.compile(
    optimizer=Adam(learning_rate=0.0002),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
discriminator.summary()

# Build the combined CGAN model
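# Freeze the discriminator's layers when training the combined CGAN model.
# The discriminator was compiled above while still trainable, so its own train_on_batch calls
# keep updating it. This compile-then-freeze pattern relies on tf.keras 2.x behavior; on newer
# Keras versions, check that the discriminator's weights actually change during training.
discriminator.trainable = False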

# Inputs for noise and label
noise_input = Input(shape=(100,))
label_input = Input(shape=(num_classes,))

# Generate an image from the noise and label input
img = generator([noise_input, label_input])

# Use the discriminator to classify the generated image with the label
validity = discriminator([img, label_input])

# Define the combined CGAN model
cgan = Model([noise_input, label_input], validity)
cgan.compile(optimizer=Adam(learning_rate=0.0002), loss='binary_crossentropy')
cgan.summary()

# Training the CGAN

# Training Parameters
epochs = 10000  # Number of training steps; each step uses one batch, not a full pass over the data
batch_size = 32
save_interval = 1000

# Labels for real and fake images
real = np.ones((batch_size, 1)) * 0.9  # Smoothed label for real images
fake = np.zeros((batch_size, 1)) + 0.1  # Smoothed label for fake images

for epoch in range(epochs):

    # Sample a random batch of real images and their labels
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_imgs = X_train[idx]
    real_labels = y_train[idx]

    # Generate fake images
    noise = np.random.normal(0, 1, (batch_size, 100))
    fake_labels = np.eye(num_classes)[np.random.choice(num_classes, batch_size)]
    gen_imgs = generator.predict([noise, fake_labels], verbose=0)

    # Train the discriminator on real and fake images
    d_loss_real = discriminator.train_on_batch([real_imgs, real_labels], real)
    d_loss_fake = discriminator.train_on_batch([gen_imgs, fake_labels], fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator via the combined CGAN model
    noise = np.random.normal(0, 1, (batch_size, 100))
    sampled_labels = np.eye(num_classes)[np.random.choice(num_classes, batch_size)]
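    # Use unsmoothed "real" targets (1.0) for the generator; label smoothing is meant for the discriminator
    g_loss = cgan.train_on_batch([noise, sampled_labels], np.ones((batch_size, 1)))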

    # Display training progress and generated samples at intervals
    if epoch % save_interval == 0:
        print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}%] [G loss: {g_loss}]")

        # Show a 2x2 grid of generated samples, titled with the class each was conditioned on
        generated_imgs = generator.predict([noise, sampled_labels], verbose=0)
        generated_imgs = 0.5 * generated_imgs + 0.5  # Rescale from [-1, 1] to [0, 1]
        class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                       'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

        plt.figure(figsize=(5, 5))
        for i in range(4):
            plt.subplot(2, 2, i + 1)
            plt.imshow(generated_imgs[i, :, :, 0], cmap='gray')
            plt.title(class_names[np.argmax(sampled_labels[i])], fontsize=8)
            plt.axis('off')
        plt.show()

Steps Implemented

Step 1: Dataset Preparation

  • Dataset: I used the Fashion MNIST dataset, which contains grayscale 28x28 images of clothing items across 10 classes.
  • Normalization: Images were normalized to the range [-1, 1] to match the output of the generator’s tanh activation.
  • One-Hot Encoding: Labels were converted to a one-hot encoding format to condition both the generator and discriminator.
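
Here is a small NumPy sketch of these preprocessing steps. It assumes uint8 pixels in [0, 255], as returned by `fashion_mnist.load_data()`. The helpers `normalize`, `denormalize`, and `one_hot` are hypothetical names. `one_hot` is meant to behave like `to_categorical` for a 1-D array of integer labels.

```python
import numpy as np

def normalize(images_uint8):
    """Scale uint8 pixels from [0, 255] to float32 in [-1, 1] (the tanh output range)."""
    return (images_uint8.astype("float32") - 127.5) / 127.5

def denormalize(images):
    """Map values in [-1, 1] back to [0, 1] for plotting."""
    return 0.5 * images + 0.5

def one_hot(labels, num_classes=10):
    """One-hot encode a 1-D array of integer class ids."""
    return np.eye(num_classes, dtype="float32")[labels]

# Two toy 28x28 "images": one all black (0), one all white (255)
raw = np.stack([np.zeros((28, 28), np.uint8), np.full((28, 28), 255, np.uint8)])
x = normalize(raw).reshape(-1, 28, 28, 1)  # add the channel axis, as in the post
y = one_hot(np.array([9, 0]))               # 9 = Ankle boot, 0 = T-shirt/top

print(x.shape)  # (2, 28, 28, 1)
print(y.shape)  # (2, 10)
```

Converting to float before subtracting avoids any surprises from uint8 arithmetic. `denormalize` is the same rescaling the training loop applies before plotting.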

Step 2: Building the Generator

The generator takes:

  • Noise: A random 100-dimensional vector.
  • Label: A one-hot encoded label vector.

The inputs are concatenated into a 110-dimensional vector and passed through:

  1. A dense layer and a reshape that produce a low-resolution 7x7x256 feature map.
  2. Two transposed convolutions (stride 2) that upsample it to 14x14 and then 28x28.
  3. Batch normalization after each block to help stabilize training.
  4. A final single-channel Conv2D layer with tanh activation, which outputs a 28x28x1 image with values in [-1, 1].

Objective: Generate a realistic image conditioned on the label.
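
As a sanity check on the architecture, the sketch below traces output shapes and parameter counts for the generator's dense and convolutional layers in plain Python. It assumes the Keras rule that a `Conv2DTranspose` with `padding='same'` and no `output_padding` multiplies the spatial size by the stride. BatchNormalization layers (4 parameters per channel) are left out, so the counts cover only the rows shown.

```python
NOISE_DIM, NUM_CLASSES = 100, 10

def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

def conv_params(k, c_in, c_out):
    # Same count for Conv2D and Conv2DTranspose: k*k*c_in*c_out weights + c_out biases
    return k * k * c_in * c_out + c_out

def transposed_same(size, stride):
    # Conv2DTranspose with padding='same' (no output_padding): out = in * stride
    return size * stride

layers = [("Dense -> Reshape", (7, 7, 256), dense_params(NOISE_DIM + NUM_CLASSES, 7 * 7 * 256))]
h = transposed_same(7, 2)
layers.append(("Conv2DTranspose(128)", (h, h, 128), conv_params(4, 256, 128)))
h = transposed_same(h, 2)
layers.append(("Conv2DTranspose(64)", (h, h, 64), conv_params(4, 128, 64)))
# Conv2D with padding='same' and stride 1 keeps the 28x28 size
layers.append(("Conv2D(1, tanh)", (h, h, 1), conv_params(7, 64, 1)))

for name, shape, params in layers:
    print(f"{name:22s} {str(shape):15s} {params:>10,d}")
```

The dense layer dominates the parameter budget (about 1.39M of the weights shown). That is typical when a small latent vector is projected up to a 7x7x256 feature map.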

Step 3: Building the Discriminator

The discriminator takes:

  • Image: A 28x28 grayscale image.
  • Label: A one-hot encoded label, projected by a Dense layer to 28 x 28 = 784 values and reshaped to 28x28x1 so it can be stacked with the image as a second channel.

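The image and the label plane are concatenated along the channel axis (giving 28x28x2), then passed through:

  1. A Flatten layer followed by two dense layers (512 and 256 units) with LeakyReLU activation for feature extraction.
  2. A final dense layer with sigmoid activation that outputs the probability that the image-label pair is real.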

Objective: Distinguish between real and fake image-label pairs.
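
The label "embedding" in the discriminator is a Dense layer applied to a one-hot vector. That amounts to selecting one learned row of its weight matrix, plus the bias. This NumPy sketch uses random stand-in weights `W` and `b` (hypothetical, not trained values) to show the projection, the reshape, and the channel-wise concatenation that turn each image-label pair into a 28x28x2 input.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_CLASSES = 10

# Stand-ins for the weights of Dense(28 * 28) in the discriminator (random here, learned in practice)
W = rng.normal(size=(NUM_CLASSES, 28 * 28))
b = rng.normal(size=(28 * 28,))

def label_plane(one_hot_labels):
    """Project one-hot labels to 784 values and reshape them into a 28x28x1 'image'."""
    projected = one_hot_labels @ W + b  # what Dense(28 * 28) computes (no activation)
    return projected.reshape(-1, 28, 28, 1)

labels = np.eye(NUM_CLASSES)[[8, 0]]  # Bag, T-shirt/top
planes = label_plane(labels)

# For a one-hot input, the Dense layer just picks row 8 of W (plus b): an embedding lookup
assert np.allclose(planes[0].ravel(), W[8] + b)

images = rng.uniform(-1, 1, size=(2, 28, 28, 1))  # placeholder images in [-1, 1]
combined = np.concatenate([images, planes], axis=-1)  # Keras Concatenate defaults to axis=-1
print(combined.shape)  # (2, 28, 28, 2)
```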

Step 4: Building the CGAN Model

The CGAN combines the generator and discriminator:

  • The generator outputs a fake image given noise and a label.
  • The discriminator evaluates the generated image-label pair.
  • The generator is trained to fool the discriminator, encouraging it to produce realistic images.

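The discriminator’s layers are frozen (discriminator.trainable = False) before the combined CGAN model is compiled, so the generator’s updates do not change the discriminator’s weights.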
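
For reference, here is the conditional GAN objective as it is usually written (following Mirza and Osindero, 2014), with the label y fed to both networks. The notation is the standard one and does not come from the post's code. The second line is the common "non-saturating" generator loss. Training the combined model with generated images labeled as real corresponds to it when the target is exactly 1. In the post's loop, z is drawn from a standard normal and y uniformly from the 10 classes.

```latex
\begin{aligned}
\min_G \max_D \; V(D, G) &= \mathbb{E}_{(x, y) \sim p_{\text{data}}}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim \mathcal{N}(0, I),\; y \sim p(y)}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big] \\
\mathcal{L}_G &= -\,\mathbb{E}_{z \sim \mathcal{N}(0, I),\; y \sim p(y)}\big[\log D(G(z \mid y) \mid y)\big]
\end{aligned}
```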

Step 5: Training the CGAN

  • Discriminator Training:
    • Trained on real and fake images with smoothed targets (0.9 for real, 0.1 for fake) to help stabilize training.
  • Generator Training:
    • Trained through the combined CGAN model with generated images labeled as real, so each update pushes the generator toward images the discriminator accepts as real for the given label.
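
The binary cross-entropy itself shows why smoothing helps. With a target of 0.9, the loss is smallest when the discriminator outputs 0.9, so it is no longer rewarded for pushing its outputs all the way to 1. Below is a minimal NumPy sketch, assuming the per-example BCE formula; Keras additionally clips predictions and averages over the batch.

```python
import numpy as np

def bce(target, pred, eps=1e-7):
    """Per-example binary cross-entropy, clipped to avoid log(0)."""
    pred = np.clip(pred, eps, 1 - eps)
    return -(target * np.log(pred) + (1 - target) * np.log(1 - pred))

preds = np.linspace(0.01, 0.99, 99)  # candidate discriminator outputs on a 0.01 grid
for target in (1.0, 0.9, 0.1):
    best = preds[np.argmin(bce(target, preds))]
    print(f"target={target}: loss is minimized at D(x)={best:.2f}")
```

With target 1.0, the minimum sits at the edge of the grid, so the loss keeps pushing outputs toward 1. With 0.9 or 0.1, the optimum is at the target itself, which keeps the discriminator's outputs away from saturation.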

Step 6: Visualizing Results

During training:

  • Every 1,000 steps (save_interval), a 2x2 grid of generated images was displayed to monitor progress.
  • The noise and labels for each grid were sampled at random, so the grid shows a mix of clothing items such as “sneakers” or “shirts.”
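
To look at one specific class rather than a random mix, the labels can be chosen explicitly. The sketch below is a hypothetical helper (`label_grid` is not from the post). It builds one-hot labels for named classes and reuses the same noise vectors across classes, so differences between rows come from the label alone. The class order follows the standard Fashion MNIST labels 0 to 9.

```python
import numpy as np

# Fashion MNIST class names in label order (0-9)
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def label_grid(names, per_class, noise_dim=100, seed=0):
    """Return (noise, one-hot labels) with `per_class` samples for each named class.

    The same noise vectors are reused for every class (hypothetical helper).
    """
    ids = np.repeat([CLASS_NAMES.index(n) for n in names], per_class)
    rng = np.random.default_rng(seed)
    base_noise = rng.normal(0.0, 1.0, size=(per_class, noise_dim))
    noise = np.tile(base_noise, (len(names), 1))  # repeat the noise block once per class
    labels = np.eye(len(CLASS_NAMES), dtype="float32")[ids]
    return noise, labels

noise, labels = label_grid(["Sneaker", "Shirt"], per_class=4)
print(noise.shape, labels.shape)  # (8, 100) (8, 10)
# With the trained generator from the code above:
# imgs = 0.5 * generator.predict([noise, labels], verbose=0) + 0.5
```

The commented last line assumes the `generator` model built earlier in the post. Each group of four images should then show one item type drawn from shared noise.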
