Day 30: SimCLR - Self-Supervised Learning Part 2 and Classifier Training
Today marks the final day of the 30 Days 30 Machine Learning Projects Challenge. I completed the SimCLR self-supervised learning framework by training it on the CIFAR-10 dataset and evaluating the learned representations using a simple classifier.
SimCLR Recap from Day 29
In Day 29, I implemented the foundational components of SimCLR:
- Data Augmentation: Generated diverse views of the same image.
- Encoder Network: Extracted meaningful features using a ResNet-50 backbone.
- Projection Head: Mapped the features to a lower-dimensional space for contrastive learning.
- Contrastive Loss: Learned representations by pulling augmented views of the same image closer and pushing others apart.
Day 30: Training and Evaluation
Today, I focused on:
- Training SimCLR to learn representations.
- Evaluating the learned representations using a simple classifier.
Step 1: Dataset Preparation for Training
The CIFAR-10 dataset was prepared by:
- Normalization: Rescaled pixel values to the range [-1, 1].
- Augmentation Pipeline: Applied random crops, flips, and color distortions to generate diverse views of the data (a sketch of such a pipeline follows below).
- Batching and Prefetching: Optimized the dataset pipeline for training efficiency.
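The data_augment function used below was built on Day 29. For reference, here is a minimal sketch of what such a pipeline can look like with tf.image ops, assuming 32x32x3 CIFAR-10 inputs already scaled to [-1, 1] (the Day 29 version may differ in its exact ops and parameters):

import tensorflow as tf

def data_augment(image):
    # Sketch of the Day 29 augmentation pipeline; details may differ
    image = tf.image.random_flip_left_right(image)
    # Pad to 36x36, then take a random 32x32 crop
    image = tf.image.resize_with_crop_or_pad(image, 36, 36)
    image = tf.image.random_crop(image, size=[32, 32, 3])
    # Color distortion: brightness and contrast jitter
    image = tf.image.random_brightness(image, max_delta=0.4)
    image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
    # Keep pixel values in the normalized [-1, 1] range
    return tf.clip_by_value(image, -1.0, 1.0)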
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 128
BUFFER_SIZE = 10000

def prepare_data(x_data):
    # Shuffle, apply a first round of augmentation, then batch and prefetch for throughput
    dataset = tf.data.Dataset.from_tensor_slices(x_data)
    dataset = dataset.shuffle(BUFFER_SIZE).map(data_augment, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
    return dataset

dataset = prepare_data(X_data)
Step 2: Training SimCLR
The SimCLR framework was trained using the contrastive loss function:
- Two Augmented Views: For each batch, two augmented views of the same image were generated.
- Forward Pass: Both views were passed through the encoder and projection head to generate latent representations.
- Contrastive Loss: Encouraged the model to pull similar views closer in the latent space while pushing others apart.
- Gradient Updates: The model parameters were updated using the Adam optimizer.
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.0003)

# Training step: augment, encode, project, and update
@tf.function
def train_step(batch):
    # Generate two independently augmented views of each image
    augmented_1 = tf.map_fn(data_augment, batch)
    augmented_2 = tf.map_fn(data_augment, batch)
    with tf.GradientTape() as tape:
        # Encode and project both views into the latent space
        z_i = project_head(encoder(augmented_1, training=True), training=True)
        z_j = project_head(encoder(augmented_2, training=True), training=True)
        # Contrastive loss pulls paired views together, pushes others apart
        loss = contrastive_loss(z_i, z_j)
    # Update the encoder and projection head jointly
    trainable_vars = encoder.trainable_variables + project_head.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(gradients, trainable_vars))
    return loss
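Here, contrastive_loss is the NT-Xent loss implemented on Day 29. For reference, a minimal sketch of that formulation, assuming L2-normalized projections and a temperature of 0.5 (the Day 29 implementation may differ in details):

def contrastive_loss(z_i, z_j, temperature=0.5):
    # Sketch of NT-Xent; the Day 29 version may differ
    # Normalize so dot products become cosine similarities
    z_i = tf.math.l2_normalize(z_i, axis=1)
    z_j = tf.math.l2_normalize(z_j, axis=1)
    batch_size = tf.shape(z_i)[0]
    z = tf.concat([z_i, z_j], axis=0)                      # (2N, d)
    sim = tf.matmul(z, z, transpose_b=True) / temperature  # (2N, 2N) logits
    # Mask out self-similarity on the diagonal
    sim -= tf.eye(2 * batch_size) * 1e9
    # Each row's positive is its counterpart view: k+N for the first half, k-N for the second
    labels = tf.concat([tf.range(batch_size) + batch_size, tf.range(batch_size)], axis=0)
    losses = tf.keras.losses.sparse_categorical_crossentropy(labels, sim, from_logits=True)
    return tf.reduce_mean(losses)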
The model was trained for 10 epochs, and the mean loss was logged after each epoch.
EPOCHS = 10
for epoch in range(EPOCHS):
    epoch_loss_avg = tf.keras.metrics.Mean()
    for batch in dataset:
        loss = train_step(batch)
        epoch_loss_avg.update_state(loss)
    print(f"Epoch: {epoch + 1}, loss: {epoch_loss_avg.result().numpy()}")
Step 3: Evaluating Representations with a Classifier
To evaluate the representations learned by SimCLR:
- Encoder Freezing: The encoder was frozen to prevent further updates.
- Simple Classifier:
  - A dense layer with a softmax activation was added on top of the frozen encoder.
  - This layer mapped the learned features to the 10 CIFAR-10 classes.
- Training the Classifier:
  - The CIFAR-10 dataset was split into training (80%) and validation (20%) subsets.
  - The classifier was trained for 5 epochs on the frozen encoder's features.
# Evaluate the learned representations with a linear probe
from tensorflow.keras.layers import Dense
from sklearn.model_selection import train_test_split

# Freeze the pre-trained encoder so only the new head learns
encoder.trainable = False

# Create a simple classifier on top of the frozen encoder
classifier = tf.keras.Sequential([
    encoder,
    Dense(10, activation='softmax')
])

classifier.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Split images and labels into training (80%) and validation (20%) sets
# (y_data holds the CIFAR-10 labels matching X_data)
X_train, X_val, y_train, y_val = train_test_split(
    X_data, y_data, test_size=0.2, random_state=42
)

# Train the classifier head on top of the frozen encoder
classifier.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=5,
    batch_size=128
)
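As a final sanity check on the linear probe, the frozen-encoder classifier can be scored on the held-out split (a usage sketch; the numbers depend on the run):

# Report the linear-probe accuracy on the validation split
val_loss, val_acc = classifier.evaluate(X_val, y_val, batch_size=128)
print(f"Validation accuracy with frozen encoder: {val_acc:.3f}")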
This concludes the 30 Days 30 Machine Learning Projects Challenge.