Day 26: Training the CycleGAN for Style Transfer - Part 2: Horse to Zebra Conversion

Today, I trained the CycleGAN model that I implemented on Day 25. Training alternates between updating the two discriminators and updating the two generators, so that the generated images take on the style of the target domain while keeping the content of the source image. This marks an exciting step in unpaired image-to-image translation.

Code:

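import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Set up the directory to save the generated images
os.makedirs("generated_images", exist_ok=True)

# Define the hyperparameters
EPOCHS = 100
BATCH_SIZE = 1
SAVE_INTERVAL = 10

# Labels for real and fake images when training the discriminators.
# The discriminator outputs a 16x16x1 grid of scores (one per image patch, PatchGAN-style),
# so the labels must have the same shape as the discriminator output.
REAL_LABEL = np.ones((BATCH_SIZE, 16, 16, 1))
FAKE_LABEL = np.zeros((BATCH_SIZE, 16, 16, 1))

for epoch in range(EPOCHS):
    # Pair horse and zebra images; take(100) limits each epoch to 100 batches
    for real_a, real_b in tf.data.Dataset.zip((train_horses, train_zebras)).take(100):
        # Translate each image to the other domain. Assumed convention (as in the
        # TensorFlow CycleGAN tutorial): generator_g maps horse (A) -> zebra (B),
        # generator_f maps zebra (B) -> horse (A).
        fake_b = generator_g.predict(real_a, verbose=0)
        fake_a = generator_f.predict(real_b, verbose=0)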

        # Train discriminator with real and fake images
        # Train Discriminator A
        d_a_real_loss = discriminator_a.train_on_batch(real_a, REAL_LABEL)
        d_a_fake_loss = discriminator_a.train_on_batch(fake_a, FAKE_LABEL)
        d_a_loss = 0.5 * np.add(d_a_real_loss, d_a_fake_loss)

        # Train Discriminator B
        d_b_real_loss = discriminator_b.train_on_batch(real_b, REAL_LABEL)
        d_b_fake_loss = discriminator_b.train_on_batch(fake_b, FAKE_LABEL)
        d_b_loss = 0.5 * np.add(d_b_real_loss, d_b_fake_loss)

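        # Train the generators to fool the discriminators while keeping cycle consistency
        # and identity. Target order must match combined_model's outputs: two adversarial
        # targets, two cycle reconstructions (A, B), two identity mappings (A, B).
        g_loss = combined_model.train_on_batch(
            [real_a, real_b],
            [REAL_LABEL, REAL_LABEL, real_a, real_b, real_a, real_b],
        )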

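    # Print the progress (losses from the last batch of the epoch).
    # d_a_loss[0] assumes the discriminators were compiled with a metric, so
    # train_on_batch returns [loss, metric]; g_loss[0] is the combined model's total loss.
    print(f"Epoch: {epoch + 1} / {EPOCHS}")
    print(f"D_A_Loss: {d_a_loss[0]:.4f}, D_B_Loss: {d_b_loss[0]:.4f}")
    print(f"G_Loss: {g_loss[0]:.4f}")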

    # Save generated images at regular intervals, using the last batch of the epoch
    if (epoch + 1) % SAVE_INTERVAL == 0:
        fake_b = generator_g.predict(real_a, verbose=0)  # horse -> zebra
        fake_a = generator_f.predict(real_b, verbose=0)  # zebra -> horse

        # Visualize the generated images
        plt.figure(figsize=(10, 6))

        plt.subplot(2, 2, 1)
        plt.title("Original Horse")
        plt.imshow((real_a[0] + 1) / 2) # Rescale to [0,1]
        plt.axis('off')

        plt.subplot(2, 2, 2)
        plt.title("Generated Zebra")
        plt.imshow((fake_b[0] + 1) / 2) # Rescale to [0,1]
        plt.axis('off')

        plt.subplot(2, 2, 3)
        plt.title("Original Zebra")
        plt.imshow((real_b[0] + 1) / 2) # Rescale to [0,1]
        plt.axis('off')

        plt.subplot(2, 2, 4)
        plt.title("Generated Horse")
        plt.imshow((fake_a[0] + 1) / 2) # Rescale to [0,1]
        plt.axis('off')

        plt.savefig(f"generated_images/epochs_{epoch + 1}.png")

        plt.show()
        plt.close()  # free the figure's memory before the next save

Step 1: Setting Up the Training Loop

  • Epochs: The model was trained for 100 epochs, each covering 100 horse/zebra pairs (the .take(100) in the loop).
  • Batch Size: A batch size of 1 was used, following the original CycleGAN paper, which trains on one image at a time.
  • Saving Interval: Generated images were saved every 10 epochs for visual evaluation.
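The schedule above can be sanity-checked with a little arithmetic. The sketch below assumes each epoch draws exactly 100 batches from the zipped dataset (the .take(100) in the loop) and that both datasets are batched with BATCH_SIZE; training_schedule is a hypothetical helper, not part of the original code.

```python
def training_schedule(epochs, batches_per_epoch, batch_size, save_interval):
    """Summarize the training loop's schedule.

    Each batch triggers two updates per discriminator (real, then fake)
    and one update of the generators through the combined model.
    """
    total_batches = epochs * batches_per_epoch        # number of generator updates
    images_per_domain = total_batches * batch_size    # horse (or zebra) images seen, with repeats
    # Same test as the loop: (epoch + 1) % SAVE_INTERVAL == 0 with a 0-based epoch
    save_epochs = [e + 1 for e in range(epochs) if (e + 1) % save_interval == 0]
    return total_batches, images_per_domain, save_epochs


total_batches, images_seen, save_epochs = training_schedule(
    epochs=100, batches_per_epoch=100, batch_size=1, save_interval=10
)
print(total_batches)  # 10000
print(images_seen)    # 10000
print(save_epochs)    # [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
```

Because iterating a tf.data dataset again starts from its beginning, .take(100) yields the same 100 pairs every epoch unless the datasets are shuffled, which is worth checking in the Day 25 data pipeline.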

Step 2: Loss Functions

  1. Discriminator Loss:

    • For each discriminator (D_A for horses and D_B for zebras), both real and generated images were passed through the network.
    • The losses on real and fake images were averaged to give the discriminator loss.
  2. Generator Loss:

    • Adversarial loss, which rewards the generators for fooling the discriminators.
    • Cycle consistency loss, which ensures that translating an image to the other domain and back reconstructs the original image.
    • Identity loss, which preserves an image's features (such as its colors) when it is already in the generator's target domain, e.g. a zebra fed to the horse-to-zebra generator should come back unchanged.
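Written out, the generator-side terms can be summarized as below. This is a sketch under stated assumptions: G is generator_g mapping horses (A) to zebras (B), F is generator_f mapping zebras to horses, the adversarial term is least-squares (MSE against the all-ones patch label), and the cycle and identity terms use the L1 distance (averaged over pixels in practice). The weights λ_cyc and λ_id stand for whatever loss weights the combined model was compiled with on Day 25; the CycleGAN paper uses λ_cyc = 10 and λ_id = 0.5 λ_cyc.

```latex
\begin{aligned}
\mathcal{L}_{\text{adv}} &= \mathbb{E}_{a}\big[(D_B(G(a)) - 1)^2\big] + \mathbb{E}_{b}\big[(D_A(F(b)) - 1)^2\big] \\
\mathcal{L}_{\text{cyc}} &= \mathbb{E}_{a}\big[\lVert F(G(a)) - a \rVert_1\big] + \mathbb{E}_{b}\big[\lVert G(F(b)) - b \rVert_1\big] \\
\mathcal{L}_{\text{id}} &= \mathbb{E}_{a}\big[\lVert F(a) - a \rVert_1\big] + \mathbb{E}_{b}\big[\lVert G(b) - b \rVert_1\big] \\
\mathcal{L}_{G} &= \mathcal{L}_{\text{adv}} + \lambda_{\text{cyc}}\,\mathcal{L}_{\text{cyc}} + \lambda_{\text{id}}\,\mathcal{L}_{\text{id}}
\end{aligned}
```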

Step 3: Training the Discriminator

  • Real images were labeled as 1 (real) and generated images as 0 (fake). Since the discriminator scores a 16x16 grid of patches, each label is a 16x16 array of ones or zeros.
  • The discriminator loss was computed by averaging the losses on the real and fake images.
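A NumPy sketch of that averaging, assuming the discriminators were compiled with mean squared error (the least-squares GAN loss the CycleGAN paper uses); if Day 25 used binary cross-entropy instead, only the per-term loss changes. patch_mse and discriminator_loss are hypothetical helpers, and the constant score grids stand in for real discriminator outputs.

```python
import numpy as np


def patch_mse(scores, target):
    """Mean squared error between a patch score grid and a label grid."""
    return float(np.mean((scores - target) ** 2))


def discriminator_loss(real_scores, fake_scores):
    """Average of real-image and fake-image losses, like 0.5 * (d_real + d_fake)."""
    real_label = np.ones_like(real_scores)   # every patch of a real image -> 1
    fake_label = np.zeros_like(fake_scores)  # every patch of a generated image -> 0
    return 0.5 * (patch_mse(real_scores, real_label) + patch_mse(fake_scores, fake_label))


# Illustrative score grids with the same (batch, 16, 16, 1) shape as REAL_LABEL
real_scores = np.full((1, 16, 16, 1), 0.9)  # discriminator fairly sure these are real
fake_scores = np.full((1, 16, 16, 1), 0.2)  # and fairly sure these are fake
print(discriminator_loss(real_scores, fake_scores))  # approximately 0.025 = 0.5 * (0.01 + 0.04)
```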

Step 4: Training the Generator

  • The generators were trained through the combined CycleGAN model, which scores the generated images with the discriminators against the "real" label (adversarial loss) and adds the cycle consistency and identity losses.
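Keras combines a multi-output model's losses as a weighted sum using the loss_weights given to compile. The NumPy sketch below mimics that for the six targets passed to combined_model in the loop, returning the total first and then the per-output losses (the layout tf.keras 2.x uses for a multi-output train_on_batch). The loss types (MSE for the two adversarial outputs, MAE for the reconstruction and identity outputs) and the weights [1, 1, 10, 10, 5, 5] are assumptions modeled on the CycleGAN paper, not values confirmed from Day 25; generator_loss is a hypothetical helper.

```python
import numpy as np


def mse(pred, target):
    return float(np.mean((pred - target) ** 2))


def mae(pred, target):
    return float(np.mean(np.abs(pred - target)))


# Assumed loss functions and weights, in the same order as the training targets
# [REAL_LABEL, REAL_LABEL, real_a, real_b, real_a, real_b]
LOSSES = [mse, mse, mae, mae, mae, mae]
WEIGHTS = [1.0, 1.0, 10.0, 10.0, 5.0, 5.0]


def generator_loss(outputs, targets):
    """Return [weighted total, *per-output losses] for the six generator outputs."""
    per_output = [fn(out, tgt) for fn, out, tgt in zip(LOSSES, outputs, targets)]
    total = sum(w * l for w, l in zip(WEIGHTS, per_output))
    return [total] + per_output


real_label = np.ones((1, 16, 16, 1))
real_a = np.zeros((1, 8, 8, 3))  # tiny stand-in images
real_b = np.zeros((1, 8, 8, 3))
outputs = [
    np.full_like(real_label, 0.5),  # D_A(F(b)): patch scores for generated horses
    np.full_like(real_label, 0.5),  # D_B(G(a)): patch scores for generated zebras
    real_a + 0.25,                  # F(G(a)): horse reconstruction
    real_b + 0.25,                  # G(F(b)): zebra reconstruction
    real_a + 0.125,                 # F(a): identity mapping of a horse
    real_b + 0.125,                 # G(b): identity mapping of a zebra
]
targets = [real_label, real_label, real_a, real_b, real_a, real_b]
loss = generator_loss(outputs, targets)
print(loss[0])  # 6.75
```

Here the cycle terms dominate the total (5.0 of 6.75), which reflects how strongly the assumed weights favor preserving the source image's content.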

Sample Outputs

During training, I saved the generated images at regular intervals to monitor progress.

Example outputs at epoch 10: generated horse and zebra images

Example outputs at epoch 100: generated horse and zebra images
