Day 29: Exploring SimCLR for Self-Supervised Learning - Part 1

Today, I started working on SimCLR, a self-supervised learning approach for representation learning. Unlike traditional supervised learning, SimCLR does not require labels for training. Instead, it leverages contrastive learning to learn meaningful representations by comparing augmented versions of the same image.

Problem Statement

The task was to implement the first part of the SimCLR framework, which involves:

  1. Loading the dataset and preparing it for self-supervised learning.
  2. Defining a robust data augmentation pipeline.
  3. Building the encoder network for feature extraction.
  4. Adding a projection head to map features into a lower-dimensional latent space.
  5. Implementing the contrastive loss function, which is the core of SimCLR.

Dataset

The CIFAR-10 dataset was used for this project. It contains:

  • 50,000 training images and 10,000 test images, each a 32x32 RGB image from one of 10 classes.
  • The class labels are discarded, since self-supervised training does not use them.

Preprocessing:

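  • The pixel values were scaled from [0, 255] to the zero-centered range [-1, 1]. (The network below uses ReLU activations, not tanh. Also note that the Keras ImageNet weights for ResNet-50 expect inputs prepared with `tf.keras.applications.resnet50.preprocess_input`, which is a different normalization.)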
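A quick check of the scaling arithmetic, as a small NumPy sketch on a few hypothetical pixel values: dividing by 127.5 and subtracting 1 maps 0 to -1 and 255 to 1, and `(x + 1) / 2` (used later for plotting) maps the result back to [0, 1].

```python
import numpy as np

# Hypothetical uint8 pixel values spanning the full 0-255 range
pixels = np.array([0, 51, 127, 255], dtype=np.uint8)

# Same transform as the post: [0, 255] -> [-1, 1]
scaled = pixels.astype("float32") / 127.5 - 1.0

# Inverse used for plotting: [-1, 1] -> [0, 1]
display = (scaled + 1.0) / 2.0

print(scaled)   # first and last entries are exactly -1.0 and 1.0
print(display)
```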

Code:


# Problem: Work on SimCLR self-supervised learning

import tensorflow as tf
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.applications import ResNet50
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10


# Load the CIFAR-10 dataset
(X_train, _), (X_val, _) = cifar10.load_data()

# Combine the training and test data.
# Labels are not needed for self-supervised training, so all 60,000 images are used
X_data = np.concatenate((X_train, X_val), axis=0)

# Scale pixel values from [0, 255] to the zero-centered range [-1, 1]
X_data = (X_data.astype('float32') / 127.5) - 1.0

# Define the augmentation pipeline (applied independently to create each view)
def data_augment(image):
    # Random crop and resize back to 32x32
    image = tf.image.random_crop(image, size=[28, 28, 3])
    image = tf.image.resize(image, (32, 32))

    # Random flip (left-right)
    image = tf.image.random_flip_left_right(image)

    # Color distortion: random brightness and contrast
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.1, upper=0.9)

    # Keep pixel values inside the normalized [-1, 1] range
    image = tf.clip_by_value(image, -1.0, 1.0)

    return image

# Visualize some augmented images
fig, axs = plt.subplots(1, 4, figsize=(10, 3))
for i in range(4):
    image = data_augment(X_data[np.random.randint(len(X_data))])
    axs[i].imshow((image.numpy() + 1) / 2)  # Rescale from [-1, 1] back to [0, 1] for display
    axs[i].axis('off')

plt.show()

# Set up the Base Network (Encoder)
def create_encoder():
    # ResNet-50 without its classification head, initialized from ImageNet weights
    base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(32, 32, 3))
    base_model.trainable = True  # Fine-tune all layers (use weights=None to train from scratch)
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = base_model(inputs)
    # For 32x32 inputs the output is a 1x1x2048 feature map; pool it to a 2048-d vector
    x = GlobalAveragePooling2D()(x)

    return tf.keras.Model(inputs, x)

encoder = create_encoder()
encoder.summary()

# Create the Projection Head on top of the encoder
def create_project_head(encoder):
    inputs = encoder.input
    x = encoder.output
    x = Dense(256, activation='relu')(x)
    output = Dense(128)(x)  # The contrastive loss is computed on this 128-d output

    return tf.keras.Model(inputs, output)

project_head = create_project_head(encoder)
project_head.summary()

# Define the Contrastive Loss (simplified: each z_i is contrasted only against the z_j batch)
def contrastive_loss(z_i, z_j, temperature=0.5):
    # L2-normalize each embedding (row) so dot products become cosine similarities
    z_i = tf.math.l2_normalize(z_i, axis=1)
    z_j = tf.math.l2_normalize(z_j, axis=1)

    # Pairwise cosine similarities between the two views, scaled by the temperature
    similarity_matrix = tf.matmul(z_i, z_j, transpose_b=True)
    logits = similarity_matrix / temperature

    # Row k's positive is column k (the other view of the same image)
    batch_size = tf.shape(z_i)[0]
    labels = tf.range(batch_size)

    # Cross-entropy over each row, with the positive pair as the target class
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

    return tf.reduce_mean(loss)
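The `contrastive_loss` function above is a simplified, one-directional variant: each `z_i[k]` is scored only against the batch of `z_j`, so the other views in its own batch never act as negatives. The NT-Xent loss described in the SimCLR paper instead stacks both views into 2N embeddings, treats every other embedding as a negative, masks out self-similarity, and averages over both directions. Below is a minimal NumPy sketch of that version, assuming `z_i` and `z_j` are (N, d) arrays whose row k comes from the same image; the name `nt_xent_loss` is hypothetical, and a TensorFlow version would swap in the equivalent `tf` ops.

```python
import numpy as np

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent loss over 2N embeddings; row k of z_i and row k of z_j are a positive pair."""
    n = z_i.shape[0]
    z = np.concatenate([z_i, z_j], axis=0)               # (2N, d)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)     # L2-normalize each row

    logits = (z @ z.T) / temperature                      # (2N, 2N) cosine similarity / tau
    np.fill_diagonal(logits, -np.inf)                     # exclude each sample's similarity to itself

    # Index of the positive for each row: k <-> k + N
    positives = np.concatenate([np.arange(n, 2 * n), np.arange(n)])

    # Numerically stable log-softmax over each row
    row_max = logits.max(axis=1, keepdims=True)
    shifted = logits - row_max
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Cross-entropy with the positive as the target, averaged over all 2N rows
    return -log_probs[np.arange(2 * n), positives].mean()


rng = np.random.default_rng(0)
z_a = rng.normal(size=(8, 128))
z_b = z_a + 0.01 * rng.normal(size=(8, 128))  # second view nearly identical to the first
z_c = rng.normal(size=(8, 128))               # unrelated "views"

loss_aligned = nt_xent_loss(z_a, z_b)
loss_random = nt_xent_loss(z_a, z_c)
print(loss_aligned, loss_random)  # aligned pairs give the lower loss
```

Because the 2N x 2N matrix includes same-view pairs, each positive competes against 2N - 2 negatives instead of N - 1, which is the main practical difference from the simplified function.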

Step 1: Data Augmentation

A robust data augmentation pipeline was defined to create different views of the same image:

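  1. Random Cropping and Resizing: Crops a random 28x28 patch and resizes it back to 32x32, introducing spatial variation.
  2. Random Flipping: Encourages the model to be invariant to left-right flips.
  3. Color Distortion: Randomly adjusts brightness and contrast so the model relies on semantic content rather than low-level intensity cues. (The SimCLR paper's color distortion also jitters saturation and hue and applies random grayscale; this pipeline uses brightness and contrast only.)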
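The same three augmentations can be sketched in plain NumPy to show how two correlated views of one image are produced. This is an illustration only, assuming an HxWx3 float image in [-1, 1]; the function names are hypothetical, resizing uses simple nearest-neighbour indexing rather than the bilinear default of `tf.image.resize`, and the brightness and contrast ranges just mirror the values in the code above.

```python
import numpy as np

def augment(image, rng, crop=28):
    """Return one randomly augmented view of an (H, W, 3) image scaled to [-1, 1]."""
    h, w, _ = image.shape

    # 1. Random crop, then nearest-neighbour resize back to (H, W)
    top = rng.integers(0, h - crop + 1)
    left = rng.integers(0, w - crop + 1)
    patch = image[top:top + crop, left:left + crop]
    rows = np.arange(h) * crop // h
    cols = np.arange(w) * crop // w
    view = patch[rows][:, cols]

    # 2. Random left-right flip with probability 0.5
    if rng.random() < 0.5:
        view = view[:, ::-1]

    # 3. Color distortion: random brightness shift, then per-channel contrast scaling
    view = view + rng.uniform(-0.5, 0.5)
    mean = view.mean(axis=(0, 1), keepdims=True)
    view = (view - mean) * rng.uniform(0.1, 0.9) + mean

    # Keep values in the valid [-1, 1] range
    return np.clip(view, -1.0, 1.0)

def two_views(image, rng):
    """SimCLR-style positive pair: two independent augmentations of one image."""
    return augment(image, rng), augment(image, rng)

rng = np.random.default_rng(42)
img = rng.uniform(-1.0, 1.0, size=(32, 32, 3))  # stand-in for a CIFAR-10 image
v1, v2 = two_views(img, rng)
print(v1.shape, v2.shape)  # (32, 32, 3) (32, 32, 3)
```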

Here’s a visualization of some augmented images:

[Figure: Augmented Image]

Step 2: Encoder Network

A ResNet-50 model was used as the encoder network:

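  • Pre-trained Weights: ImageNet weights were loaded for initialization, and all layers are left trainable for fine-tuning.
  • Global Average Pooling: For a 32x32 input, ResNet-50's last stage outputs a 1x1x2048 feature map, which global average pooling turns into a 2048-dimensional feature vector.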
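ResNet-50 downsamples its input by a total factor of 32, which is why a CIFAR-10 image ends up as a 1x1 spatial map. The sketch below traces the spatial size through the stride-2 layers with the standard output-size formula; the layer list is a summary of the Keras ResNet50 stem and stage transitions (kernel and padding values assumed from that architecture), not code from the post.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) for the layers that change spatial size in ResNet-50
resnet50_downsampling = [
    ("stem 7x7 conv", 7, 2, 3),
    ("3x3 max pool", 3, 2, 1),
    ("conv3_x (stride 2)", 1, 2, 0),
    ("conv4_x (stride 2)", 1, 2, 0),
    ("conv5_x (stride 2)", 1, 2, 0),
]

size = 32  # CIFAR-10 height/width
trace = [size]
for name, k, s, p in resnet50_downsampling:
    size = conv_out(size, k, s, p)
    trace.append(size)
    print(f"{name:>20}: {size}x{size}")

print("final feature map:", f"{size}x{size}x2048")
```

The trace goes 32, 16, 8, 4, 2, 1. This aggressive downsampling is why the SimCLR paper's CIFAR-10 experiments modify the stem (a 3x3 stride-1 first convolution and no initial max pooling) to keep more spatial resolution.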

The encoder serves as the backbone of SimCLR, extracting meaningful representations from the input images.

Step 3: Projection Head

A projection head was added on top of the encoder:

  • Dense Layers:
    • A hidden layer with 256 neurons and ReLU activation.
    • A final output layer with 128 neurons for contrastive learning.

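This projection head maps the 2048-dimensional encoder features to a 128-dimensional latent space where the contrastive loss is applied. In SimCLR, the encoder output (before the projection head) is the representation kept for downstream tasks.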
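To make the shapes concrete, here is a NumPy sketch of the projection head's forward pass, h (2048) to 256 with ReLU to z (128), using random weights in place of trained ones. The batch size and weight initialization scale are arbitrary assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, feat_dim, hidden_dim, proj_dim = 4, 2048, 256, 128

# Encoder output h: one 2048-d vector per image (random stand-in values)
h = rng.normal(size=(batch, feat_dim))

# Randomly initialized weights for the two Dense layers (illustrative scale)
W1 = rng.normal(scale=1 / np.sqrt(feat_dim), size=(feat_dim, hidden_dim))
b1 = np.zeros(hidden_dim)
W2 = rng.normal(scale=1 / np.sqrt(hidden_dim), size=(hidden_dim, proj_dim))
b2 = np.zeros(proj_dim)

# Dense(256, relu) followed by a linear Dense(128)
hidden = np.maximum(h @ W1 + b1, 0.0)
z = hidden @ W2 + b2

print(h.shape, hidden.shape, z.shape)  # (4, 2048) (4, 256) (4, 128)
```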

Step 4: Contrastive Loss Function

The contrastive loss encourages the model to:

  • Bring augmented views of the same image closer together in the latent space.
  • Push apart representations of different images.
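For reference, the SimCLR paper writes this objective (NT-Xent, the normalized temperature-scaled cross-entropy) for a positive pair (i, j) among the 2N augmented samples in a batch as follows. Here z_i and z_j denote individual projection vectors (in the code, `z_i` and `z_j` are batches of them), sim is cosine similarity, and τ is the temperature (`temperature=0.5` in the code). The full loss averages ℓ over all positive pairs, in both orders (i, j) and (j, i).

```latex
\ell_{i,j} = -\log \frac{\exp\left(\operatorname{sim}(z_i, z_j)/\tau\right)}
                        {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\operatorname{sim}(z_i, z_k)/\tau\right)},
\qquad
\operatorname{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}
```

A smaller τ sharpens the softmax, so hard negatives (those most similar to the anchor) dominate the gradient.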

Next Steps

  1. Training the SimCLR Framework:

    • Implement the training loop using the encoder, projection head, and contrastive loss.
    • Evaluate the representations learned by SimCLR.
  2. Visualization:

    • Plot the loss curve during training.
    • Use a t-SNE plot to visualize the learned representations in 2D space.
