Day 24: Exploring Conditional GANs (CGANs) With Fashion MNIST
Today, I dived into Conditional GANs (CGANs), an exciting variation of GANs that allows for generating specific types of images conditioned on labels. Using the Fashion MNIST dataset, I implemented a CGAN to generate images of specific clothing items like shirts, shoes, and bags.
What is a Conditional GAN?
A Conditional GAN (CGAN) is an extension of GANs where the generation of data is conditioned on some additional information, such as labels or attributes. In this project:
- The generator takes both noise (random input) and a label (e.g., “sneaker”).
- The discriminator evaluates whether an image-label pair is real or fake.
This setup allows the CGAN to generate specific types of images based on the provided label.
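For reference, the objective usually written for a conditional GAN is sketched below. This is the common textbook formulation rather than something derived in this post; x is a real image, y its label, z the noise vector, and the distribution names are notation assumed here.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{(x, y) \sim p_{\text{data}}}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z,\; y \sim p_y}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]
```

Compared with a plain GAN, the only change is that both D and G also receive the label y.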
Code:
# Problem: Conditional GAN (CGAN) for Generating Specific Images from Fashion MNIST
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Dense, Flatten, Reshape, LeakyReLU, BatchNormalization,
    Conv2DTranspose, Conv2D, Input, Concatenate,
)
from tensorflow.keras.models import Sequential, Model
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
# Load the Fashion MNIST dataset
(X_train, y_train), (_, _) = fashion_mnist.load_data()
# Normalize the images to the range [-1, 1] to fit the tanh activation function in the generator
X_train = (X_train - 127.5) / 127.5
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
# One-hot encode the labels for conditioning
num_classes = 10
y_train = to_categorical(y_train, num_classes)
# Function to build the generator
def build_generator():
    # Inputs for the generator
    noise_input = Input(shape=(100,))
    label_input = Input(shape=(num_classes,))
    # Concatenate noise and label to create the input for the generator
    model_input = Concatenate()([noise_input, label_input])
    x = Dense(7 * 7 * 256, activation='relu')(model_input)
    x = Reshape((7, 7, 256))(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Upsample to 14x14
    x = Conv2DTranspose(128, kernel_size=4, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Upsample to 28x28
    x = Conv2DTranspose(64, kernel_size=4, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Final layer to generate an image with 28x28 dimensions and 1 channel
    img_output = Conv2D(1, kernel_size=7, activation='tanh', padding='same')(x)
    return Model([noise_input, label_input], img_output)
# Build the generator model
generator = build_generator()
generator.summary()
# Function to build the discriminator
def build_discriminator():
    # Inputs for the image and the label
    img_input = Input(shape=(28, 28, 1))
    label_input = Input(shape=(num_classes,))
    # Embed the label and reshape to match the image shape
    label_embedding = Dense(28 * 28)(label_input)
    label_embedding = Reshape((28, 28, 1))(label_embedding)
    # Concatenate the image and label embedding along the channel axis -> (28, 28, 2)
    combined_input = Concatenate()([img_input, label_embedding])
    # Flatten the combined input and pass through dense layers
    x = Flatten()(combined_input)
    x = Dense(512)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dense(256)(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Final output layer to classify real (1) or fake (0)
    validity_output = Dense(1, activation='sigmoid')(x)
    # Create the model that takes the image and label as input
    return Model([img_input, label_input], validity_output)
# Build and compile the discriminator model
discriminator = build_discriminator()
discriminator.compile(
    optimizer=Adam(learning_rate=0.0002),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
discriminator.summary()
# Build the combined CGAN model
# Freeze the discriminator's layers when training the combined CGAN model
discriminator.trainable = False
# Inputs for noise and label
noise_input = Input(shape=(100,))
label_input = Input(shape=(num_classes,))
# Generate an image from the noise and label input
img = generator([noise_input, label_input])
# Use the discriminator to classify the generated image with the label
validity = discriminator([img, label_input])
# Define the combined CGAN model
cgan = Model([noise_input, label_input], validity)
cgan.compile(optimizer=Adam(learning_rate=0.0002), loss='binary_crossentropy')
cgan.summary()
# Training the CGAN
# Training Parameters
epochs = 10000
batch_size = 32
save_interval = 1000
# Labels for real and fake images
real = np.ones((batch_size, 1)) * 0.9 # Smoothed label for real images
fake = np.zeros((batch_size, 1)) + 0.1 # Smoothed label for fake images
for epoch in range(epochs):  # each "epoch" here is one batch update, not a full pass over the data
    # Sample a random batch of real images and their labels
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_imgs = X_train[idx]
    real_labels = y_train[idx]
    # Generate fake images (verbose=0 suppresses the per-call progress bar)
    noise = np.random.normal(0, 1, (batch_size, 100))
    fake_labels = np.eye(num_classes)[np.random.choice(num_classes, batch_size)]
    gen_imgs = generator.predict([noise, fake_labels], verbose=0)
    # Train the discriminator on real and fake images
    d_loss_real = discriminator.train_on_batch([real_imgs, real_labels], real)
    d_loss_fake = discriminator.train_on_batch([gen_imgs, fake_labels], fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    # Train the generator via the combined CGAN model (targets are "real" so it learns to fool D)
    noise = np.random.normal(0, 1, (batch_size, 100))
    sampled_labels = np.eye(num_classes)[np.random.choice(num_classes, batch_size)]
    g_loss = cgan.train_on_batch([noise, sampled_labels], real)
    # Display training progress and save images at intervals
    if epoch % save_interval == 0:
        # Note: with smoothed targets (0.9 / 0.1) the accuracy value is probably not meaningful,
        # since Keras compares thresholded predictions against the exact target values
        print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}%] [G loss: {g_loss}]")
        # Generate images to visualize training progress
        generated_imgs = generator.predict([noise, sampled_labels], verbose=0)
        generated_imgs = 0.5 * generated_imgs + 0.5 # Rescale from [-1, 1] to [0, 1]
        plt.figure(figsize=(5, 5))
        for i in range(4):
            plt.subplot(2, 2, i + 1)
            plt.imshow(generated_imgs[i, :, :, 0], cmap='gray')
            plt.axis('off')
        plt.show()
Steps Implemented
Step 1: Dataset Preparation
- Dataset: I used the Fashion MNIST dataset, which contains grayscale 28x28 images of clothing items across 10 classes.
- Normalization: Images were normalized to the range [-1, 1] to match the output of the generator’s tanh activation.
- One-Hot Encoding: Labels were converted to a one-hot encoding format to condition both the generator and discriminator.
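The two preprocessing steps can be sketched with NumPy alone. The helper name preprocess and the random stand-in data are assumptions for illustration; the post itself applies the same arithmetic directly to the arrays returned by fashion_mnist.load_data().

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
    """Scale uint8 images to [-1, 1] and one-hot encode integer labels."""
    # Map pixel values 0..255 to -1..1, the output range of tanh
    x = (images.astype("float32") - 127.5) / 127.5
    # Add a trailing channel axis: (N, H, W) -> (N, H, W, 1)
    x = x[..., np.newaxis]
    # Row i of the identity matrix is the one-hot vector for class i
    y = np.eye(num_classes, dtype="float32")[labels]
    return x, y

# Random stand-in data with the same shape and dtype as Fashion MNIST (not real images)
rng = np.random.default_rng(0)
demo_images = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
demo_labels = rng.integers(0, 10, size=8)

x, y = preprocess(demo_images, demo_labels)
print(x.shape, y.shape)  # (8, 28, 28, 1) (8, 10)
```

Indexing np.eye with the label array gives the same result as to_categorical for integer labels in the range 0 to 9.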
Step 2: Building the Generator
The generator takes:
- Noise: A random 100-dimensional vector.
- Label: A one-hot encoded label vector.
The inputs are concatenated and passed through:
- Dense and reshaping layers to form a low-resolution image.
- Transposed convolutions to upsample the image to 28x28.
- Batch normalization to stabilize training and speed up convergence.
- A final layer with tanh activation to output an image.
Objective: Generate a realistic image conditioned on the label.
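To make the shapes concrete, here is a small sketch of the size arithmetic behind the generator. It assumes the Keras behavior that a transposed convolution with padding='same' multiplies the spatial size by its stride; the helper name is hypothetical.

```python
def transpose_conv_same_size(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

latent_dim, num_classes = 100, 10
# Noise and one-hot label are concatenated into a single input vector
generator_input_dim = latent_dim + num_classes

size = 7  # spatial size after Dense(7 * 7 * 256) and Reshape((7, 7, 256))
sizes = [size]
for stride in (2, 2):  # the two Conv2DTranspose layers, each with strides=2
    size = transpose_conv_same_size(size, stride)
    sizes.append(size)

print(generator_input_dim)  # 110
print(sizes)                # [7, 14, 28]
```

Starting at 7x7 is what lets two stride-2 upsampling steps land exactly on the 28x28 resolution of Fashion MNIST.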
Step 3: Building the Discriminator
The discriminator takes:
- Image: A 28x28 grayscale image.
- Label: A one-hot encoded label embedded and reshaped to match the image dimensions.
The inputs are concatenated and passed through:
- Dense layers with LeakyReLU activation for feature extraction.
- A final dense layer with sigmoid activation to classify the input as real or fake.
Objective: Distinguish between real and fake image-label pairs.
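The label conditioning in the discriminator can be illustrated in NumPy. In this sketch the weight matrix W and bias b stand in for the learned Dense(28 * 28) layer and are random, so only the shapes are meaningful.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_classes = 4, 10

# Stand-in batch: random "images" in [-1, 1] and random one-hot labels
images = rng.uniform(-1, 1, size=(batch_size, 28, 28, 1)).astype("float32")
labels = np.eye(num_classes, dtype="float32")[rng.integers(0, num_classes, batch_size)]

# Stand-in for Dense(28 * 28): learned in the real model, random here
W = rng.normal(0, 0.02, size=(num_classes, 28 * 28)).astype("float32")
b = np.zeros(28 * 28, dtype="float32")

# Project each label to 784 values and reshape it into a 28x28 single-channel map
label_map = (labels @ W + b).reshape(batch_size, 28, 28, 1)

# Stack the image and the label map along the channel axis
combined = np.concatenate([images, label_map], axis=-1)
print(combined.shape)  # (4, 28, 28, 2)
```

Because the label is one-hot, the dense projection simply selects one row of W per sample, so each class effectively gets its own learned 28x28 map.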
Step 4: Building the CGAN Model
The CGAN combines the generator and discriminator:
- The generator outputs a fake image given noise and a label.
- The discriminator evaluates the generated image-label pair.
- The generator is trained to fool the discriminator, encouraging it to produce realistic images.
The discriminator’s layers are frozen while training the combined CGAN model, so that step updates only the generator’s weights.
Step 5: Training the CGAN
- Discriminator Training: Trained on real and fake images with smoothed labels (0.9 for real, 0.1 for fake) to improve stability.
- Generator Training: Trained using the combined CGAN model to maximize the discriminator’s classification error on fake images.
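The effect of the smoothed targets on both losses can be sketched with a hand-written binary cross-entropy. The discriminator outputs below (0.8 on real, 0.3 on fake) are made-up numbers for illustration, not results from this training run.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    losses = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return float(np.mean(losses))

batch_size = 4
real_target = np.full((batch_size, 1), 0.9)  # smoothed "real" label
fake_target = np.full((batch_size, 1), 0.1)  # smoothed "fake" label

# Hypothetical discriminator outputs for a real batch and a generated batch
d_on_real = np.full((batch_size, 1), 0.8)
d_on_fake = np.full((batch_size, 1), 0.3)

# Discriminator step: average of the loss on real and on generated pairs
d_loss = 0.5 * (binary_crossentropy(real_target, d_on_real)
                + binary_crossentropy(fake_target, d_on_fake))

# Generator step: the same generated pairs, scored against the "real" target
g_loss = binary_crossentropy(real_target, d_on_fake)

print(f"D loss: {d_loss:.4f}, G loss: {g_loss:.4f}")
```

With a target of 0.9 the loss is smallest when the discriminator outputs 0.9 rather than 1.0, which discourages overconfident predictions.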
Step 6: Visualizing Results
During training:
- Images were generated at regular intervals to monitor progress.
- Noise and labels were sampled at random for each preview grid; fixing the label instead (for example to “sneaker” or “shirt”) shows one specific clothing item.
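To request one particular item, the label batch can be fixed to a single class. The sketch below builds such inputs; the helper make_generator_inputs is hypothetical, and CLASS_NAMES lists the standard Fashion MNIST classes in label order.

```python
import numpy as np

# Fashion MNIST class names in label order (list index = integer label)
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def make_generator_inputs(class_name, n_samples, latent_dim=100, seed=None):
    """Return (noise, one-hot labels) asking for n_samples of a single class."""
    rng = np.random.default_rng(seed)
    class_index = CLASS_NAMES.index(class_name)
    # Different noise per sample gives different variations of the same item
    noise = rng.normal(0, 1, size=(n_samples, latent_dim)).astype("float32")
    # Every row carries the same one-hot label
    labels = np.zeros((n_samples, len(CLASS_NAMES)), dtype="float32")
    labels[:, class_index] = 1.0
    return noise, labels

noise, labels = make_generator_inputs("Sneaker", n_samples=4, seed=0)
print(noise.shape, labels.shape)  # (4, 100) (4, 10)

# With the trained generator from this post, the images would then come from:
# sneaker_imgs = generator.predict([noise, labels], verbose=0)
```

Keeping the noise fixed and changing only the label is another useful check, since it shows how much of the output the label actually controls.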