Day 25: Exploring CycleGAN for Style Transfer - Part 1: Model Architecture and Setup
Today, I began implementing CycleGAN, a type of Generative Adversarial Network (GAN) designed for unpaired image-to-image translation. This project focuses on style transfer, specifically converting images of horses into zebras and vice versa. CycleGAN enables style transformation without needing paired datasets, which makes it versatile and powerful for many real-world applications.
In this first part, I focused on:
- Loading and preprocessing the dataset.
- Building the generator and discriminator networks.
- Constructing the combined CycleGAN model.
What is a CycleGAN?
CycleGAN is a GAN-based architecture for unpaired image-to-image translation. Unlike traditional GANs, CycleGAN uses cycle consistency loss, which ensures that translating an image to another domain and back results in the original image.
Key Components:
- Generators:
  - `G`: Transforms images from Domain A (horses) to Domain B (zebras).
  - `F`: Transforms images from Domain B (zebras) to Domain A (horses).
- Discriminators:
  - `D_A`: Distinguishes real horses from fake horses generated by `F`.
  - `D_B`: Distinguishes real zebras from fake zebras generated by `G`.
- Cycle Consistency Loss:
  - Ensures that when an image is transformed from one domain to another and then back, it closely resembles the original (a minimal loss sketch follows this list).
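To make the cycle consistency idea concrete, here is a minimal sketch of how such a loss can be computed in TensorFlow. The tensor names (`real_image`, `cycled_image`) and the weight `LAMBDA = 10` are illustrative choices (the CycleGAN paper weights the cycle term by 10):

```python
import tensorflow as tf

# Illustrative weight for the cycle term (10 in the CycleGAN paper)
LAMBDA = 10

def cycle_consistency_loss(real_image, cycled_image):
    # Mean absolute (L1) difference between the original image and its
    # reconstruction after a full A -> B -> A (or B -> A -> B) cycle
    return LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))
```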
Code
```python
# Problem: Implement CycleGAN for style transfer (e.g., horse to zebra conversion)
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU, ReLU, BatchNormalization
from tensorflow.keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
import os
from glob import glob
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.data import Dataset
# Load the dataset
HORSE_DIR = 'dataset/horse2zebra/trainA/'
ZEBRA_DIR = 'dataset/horse2zebra/trainB/'
# Helper function to load images from directories
def load_images_from_directory(directory, size=(128, 128)):
    images = []
    for filepath in glob(os.path.join(directory, '*.jpg')):
        image = load_img(filepath, target_size=size)
        image = img_to_array(image)
        images.append(image)
    return np.array(images)
# Load horse and zebra images
horse_images = load_images_from_directory(HORSE_DIR)
zebra_images = load_images_from_directory(ZEBRA_DIR)
# Normalize pixel values to the range [-1, 1]
horse_images = (horse_images - 127.5) / 127.5
zebra_images = (zebra_images - 127.5) / 127.5
# Convert to TensorFlow datasets and batch them
train_horses = Dataset.from_tensor_slices(horse_images).batch(1)
train_zebras = Dataset.from_tensor_slices(zebra_images).batch(1)
# Build the Generator Model
def build_generator():
    inputs = Input(shape=(128, 128, 3))
    # Encoder: downsampling layers
    x = Conv2D(64, kernel_size=4, strides=2, padding='same')(inputs)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(128, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(256, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Decoder: upsampling layers
    x = Conv2DTranspose(128, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2DTranspose(64, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # tanh keeps outputs in [-1, 1], matching the input normalization
    x = Conv2DTranspose(3, kernel_size=4, strides=2, padding='same', activation='tanh')(x)
    return Model(inputs, x)
# Build generator for both transformations
generator_g = build_generator() # Horse to zebra
generator_f = build_generator() # Zebra to horse
generator_g.summary()
generator_f.summary()
# Build the discriminator model (PatchGAN-style: outputs a grid of
# real/fake scores rather than a single scalar)
def build_discriminator():
    inputs = Input(shape=(128, 128, 3))
    x = Conv2D(64, kernel_size=4, strides=2, padding='same')(inputs)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(128, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(256, kernel_size=4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Final 1-channel conv yields a 16x16x1 patch of scores
    x = Conv2D(1, kernel_size=4, padding='same')(x)
    return Model(inputs, x)
# Build a discriminator for each domain
discriminator_a = build_discriminator()  # For Domain A (horses)
discriminator_b = build_discriminator()  # For Domain B (zebras)
discriminator_a.compile(
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    loss='mse',
    metrics=['accuracy']
)
discriminator_b.compile(
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    loss='mse',
    metrics=['accuracy']
)
discriminator_a.summary()
discriminator_b.summary()
# Build the combined CycleGAN model
def build_combined(generator_g, generator_f, discriminator_a, discriminator_b):
    # Freeze the discriminators while training the generators
    discriminator_a.trainable = False
    discriminator_b.trainable = False
    # Real input images for both domains
    input_a = Input(shape=(128, 128, 3))  # Horses
    input_b = Input(shape=(128, 128, 3))  # Zebras
    # Forward cycle: A -> B -> A
    fake_b = generator_g(input_a)
    cycle_a = generator_f(fake_b)
    # Backward cycle: B -> A -> B
    fake_a = generator_f(input_b)
    cycle_b = generator_g(fake_a)
    # Identity mapping: a generator fed an image from its target
    # domain should return it unchanged
    same_a = generator_f(input_a)
    same_b = generator_g(input_b)
    # Discriminator scores for the generated images
    valid_a = discriminator_a(fake_a)
    valid_b = discriminator_b(fake_b)
    # Define the combined model
    model = Model(
        inputs=[input_a, input_b],
        outputs=[valid_a, valid_b, cycle_a, cycle_b, same_a, same_b]
    )
    model.compile(
        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
        # Adversarial terms use MSE (least-squares GAN); cycle and
        # identity terms use MAE (L1), as in the CycleGAN paper
        loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
        loss_weights=[1, 1, 10, 10, 5, 5]
    )
    return model
combined_model = build_combined(generator_g, generator_f, discriminator_a, discriminator_b)
combined_model.summary()
```
Step 1: Dataset Preparation
I used the horse2zebra dataset from CycleGAN’s original implementation.
- Loading Images: Images were loaded with the `glob` module, iterating over the files in the dataset folders, and resized to `128x128` for faster computation.
- Normalization: Pixel values were normalized to the range `[-1, 1]` to match the output of the generator's `tanh` activation function.
- Batching: The preprocessed images were converted into TensorFlow datasets and batched for training. (An alternative loading path via TensorFlow Datasets is sketched below.)
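If you don't have the raw JPEG folders on disk, the same dataset is also published through TensorFlow Datasets. This is a minimal sketch assuming the `tensorflow-datasets` package is installed; the `cycle_gan/horse2zebra` config and the `trainA`/`trainB` split names come from the TFDS catalog:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the horse (trainA) and zebra (trainB) training splits
horses, zebras = tfds.load('cycle_gan/horse2zebra',
                           split=['trainA', 'trainB'],
                           as_supervised=True)

def preprocess(image, _label):
    # Resize and rescale to [-1, 1] to match the generator's tanh output
    image = tf.image.resize(image, (128, 128))
    return (tf.cast(image, tf.float32) - 127.5) / 127.5

train_horses = horses.map(preprocess).batch(1)
train_zebras = zebras.map(preprocess).batch(1)
```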
Step 2: Building the Generator
The generator architecture consists of:
- Encoder: Downsampling layers using `Conv2D` and `LeakyReLU`.
- Decoder: Upsampling layers using `Conv2DTranspose` and `ReLU`.
- Output: A final layer with a `tanh` activation to generate images in the range `[-1, 1]`.
Two Generators:
- `G`: Converts horses to zebras.
- `F`: Converts zebras to horses.
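Note that this is a simplified encoder-decoder: the original CycleGAN generator inserts a stack of residual blocks between the encoder and decoder. A minimal sketch of such a block follows; the function name is my own, and I use `BatchNormalization` as a stand-in for the instance normalization used in the paper:

```python
from tensorflow.keras.layers import Add, BatchNormalization, Conv2D, ReLU

def residual_block(x, filters=256):
    # Two 3x3 convolutions with a skip connection; the CycleGAN paper's
    # generator stacks 6-9 of these between encoder and decoder
    shortcut = x
    y = Conv2D(filters, kernel_size=3, padding='same')(x)
    y = BatchNormalization()(y)
    y = ReLU()(y)
    y = Conv2D(filters, kernel_size=3, padding='same')(y)
    y = BatchNormalization()(y)
    # Skip connection: add the block input back to its output
    return Add()([shortcut, y])
```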
Step 3: Building the Discriminator
The discriminator architecture:
- Uses `Conv2D` layers for feature extraction and downsampling.
- Applies `LeakyReLU` activation after each convolutional layer.
- Outputs a `16x16x1` grid of scores (a PatchGAN-style output), each indicating whether the corresponding patch of the input image looks real or fake.
Two Discriminators:
- `D_A`: Classifies images in Domain A (horses).
- `D_B`: Classifies images in Domain B (zebras).
Each discriminator is trained to minimize the mean squared error (`mse`) loss, with target labels shaped to match the patch output (see the sketch below).
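Because the discriminator emits a patch of scores instead of a single scalar, the real/fake labels used during training have to be shaped accordingly. A minimal sketch, assuming the `128x128` input and three stride-2 convolutions above (the names `patch_shape`, `real_labels`, and `fake_labels` are illustrative):

```python
import numpy as np

# Three stride-2 convolutions reduce 128x128 inputs to a 16x16x1 score map
patch_shape = discriminator_a.output_shape[1:]  # (16, 16, 1)

batch_size = 1
real_labels = np.ones((batch_size,) + patch_shape)   # targets for real images
fake_labels = np.zeros((batch_size,) + patch_shape)  # targets for generated images
```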
Step 4: Combining the Models
The combined CycleGAN model includes:
- Forward Cycle (`A -> B -> A`): Translates a horse to a zebra and back to a horse.
- Backward Cycle (`B -> A -> B`): Translates a zebra to a horse and back to a zebra.
- Identity Mapping (`F(A) ≈ A`, `G(B) ≈ B`): Ensures that a generator fed an image already from its target domain returns it essentially unchanged, preserving its features.
- Discriminator Feedback: `D_A` and `D_B` score the generated images and provide feedback to the generators.
Loss Functions:
- Adversarial Loss: Encourages generators to produce realistic images.
- Cycle Consistency Loss: Penalizes discrepancies between input and reconstructed images.
- Identity Loss: Preserves color and style during translation.
A sketch of how these pieces could be wired into a single training step follows.
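As a preview of Part 2, here is a minimal sketch of what one training step with these compiled models could look like. The loop variables (`real_a`, `real_b`) and the reuse of the patch labels from the Step 3 sketch are my own assumptions, not code from this part:

```python
# One illustrative training step on a single horse/zebra batch
for real_a, real_b in zip(train_horses.take(1), train_zebras.take(1)):
    real_a, real_b = real_a.numpy(), real_b.numpy()

    # 1. Train the discriminators on real vs. generated images
    fake_b = generator_g.predict(real_a, verbose=0)
    fake_a = generator_f.predict(real_b, verbose=0)
    discriminator_a.train_on_batch(real_a, real_labels)
    discriminator_a.train_on_batch(fake_a, fake_labels)
    discriminator_b.train_on_batch(real_b, real_labels)
    discriminator_b.train_on_batch(fake_b, fake_labels)

    # 2. Train the generators through the combined model (the
    #    discriminators are frozen inside it); targets follow the
    #    output order [valid_a, valid_b, cycle_a, cycle_b, same_a, same_b]
    combined_model.train_on_batch(
        [real_a, real_b],
        [real_labels, real_labels, real_a, real_b, real_a, real_b]
    )
```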
Challenges Faced
- Balancing Loss Terms: Combining multiple loss terms with different weights required careful tuning.
- Resource Requirements: The model is computationally intensive due to its two generators and two discriminators.
- Image Quality: Initial results were blurry, likely due to limited training data and early-stage model adjustments.
Next Steps
In the next part of this project, I’ll focus on:
- Training the CycleGAN model:
- Implementing the training loop for the generators and discriminators.
- Monitoring loss values and generated images to ensure convergence.
- Evaluating Results:
- Visualizing transformed images during training.
- Comparing generated outputs to assess cycle consistency.
This is an exciting exploration into unpaired style transfer with CycleGAN. Stay tuned for Part 2, where I’ll dive into training the model and visualizing the results!