Day 14: Building a Custom CNN-Based Student Model Using a Pre-Trained Teacher Model

In this mini-project, we use Knowledge Distillation to train a small student CNN that learns from a larger, pre-trained teacher model. The student learns from the true labels and also from the teacher's softened output probabilities, which carry information about how similar the classes are. The goal is a student that comes close to the teacher's accuracy while being much cheaper to run. Below is a step-by-step breakdown of the Python code used to implement this process.

Step-by-Step Code Explanation

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten,
                                     GlobalAveragePooling2D, MaxPooling2D)
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.models import Model, Sequential
  • Import Libraries: We use TensorFlow and Keras to build and train the models. ResNet50 provides the pre-trained teacher backbone, and the layer classes are used to build both the teacher's classification head and the student CNN. cifar10 supplies the dataset. The backend module (K) and categorical_crossentropy are used in the custom distillation loss.

Step 1: Load and Prepare the Teacher Model

# Load ResNet50 with ImageNet weights as the larger, pre-trained teacher backbone
teacher_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))

# Add custom classification layers on top
x = teacher_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

# Final Teacher Model
teacher_model = Model(inputs=teacher_model.inputs, outputs=predictions)
  • Teacher Model Setup: We use ResNet50 as the teacher model, which is pre-trained on ImageNet.
    • include_top=False: This excludes the default fully connected layers, allowing us to add custom layers.
    • Custom Layers: After obtaining the feature maps from ResNet50, we add:
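      • GlobalAveragePooling2D: Averages each feature map over its spatial dimensions. This turns ResNet50's 4x4x2048 output (for 128x128 inputs) into a single 2,048-element vector per image.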
      • Dense Layer (256 units): A dense layer with ReLU activation.
      • Dense Layer (10 units, softmax): Outputs probabilities for 10 classes.
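The shapes become concrete with a quick check. With a 128x128 input, ResNet50's final feature map is 4x4x2048, because its total stride is 32. GlobalAveragePooling2D then averages over the two spatial axes. The NumPy sketch below reproduces that operation on random stand-in data; it is not the Keras layer itself. It also counts the parameters added by the two new Dense layers.

```python
import numpy as np

def global_average_pool(feature_maps):
    """Average each channel over height and width: (batch, H, W, C) -> (batch, C).

    Mirrors what GlobalAveragePooling2D computes with the default channels_last layout.
    """
    return feature_maps.mean(axis=(1, 2))

# Random stand-in for ResNet50's output on a batch of two 128x128 images
features = np.random.rand(2, 4, 4, 2048).astype(np.float32)
pooled = global_average_pool(features)
print(pooled.shape)  # (2, 2048)

# Parameters of the new head: Dense(256) then Dense(10), weights + biases each
head_params = (2048 * 256 + 256) + (256 * 10 + 10)
print(head_params)  # 527114
```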
# Compile the teacher model
teacher_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
  • Compile Teacher Model: The teacher model is compiled with the Adam optimizer and accuracy as the metric. The loss is sparse categorical cross-entropy, because CIFAR-10 labels are integer class indices (0-9) rather than one-hot vectors.
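The two cross-entropy variants compute the same quantity and differ only in how the labels are encoded. Sparse cross-entropy takes integer class indices, while categorical cross-entropy takes one-hot vectors. The NumPy sketch below shows the equivalence on made-up probabilities. It is an illustration, not Keras's implementation, which clips probabilities instead of adding `eps`.

```python
import numpy as np

def sparse_categorical_ce(labels, probs, eps=1e-7):
    """Cross-entropy with integer class labels: -log p[label] per example."""
    return -np.log(probs[np.arange(len(labels)), labels] + eps)

def categorical_ce(one_hot, probs, eps=1e-7):
    """Cross-entropy with one-hot labels: -sum(y * log p) per example."""
    return -np.sum(one_hot * np.log(probs + eps), axis=-1)

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([0, 2])       # integer labels, like CIFAR-10's y_train
one_hot = np.eye(3)[labels]     # the same labels, one-hot encoded

print(sparse_categorical_ce(labels, probs))  # about [0.357, 0.223]
print(categorical_ce(one_hot, probs))        # same values
```

This matters whenever a loss is written with categorical_crossentropy, as the distillation loss is. Its labels must then be one-hot encoded.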

Step 2: Load and Preprocess the Dataset

# Load the dataset
(X_train, y_train), (X_val, y_val) = cifar10.load_data()

# Normalize the pixel values between 0-1
X_train = X_train / 255.0
X_val = X_val / 255.0

# Resize the images to 128x128 to match the teacher's input shape
X_train = tf.image.resize(X_train, (128, 128))
X_val = tf.image.resize(X_val, (128, 128))
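  • Dataset Preparation: The CIFAR-10 dataset is loaded. It contains 50,000 training and 10,000 test images (the test images serve as the validation set here), each 32x32 pixels with 3 color channels.
    • Normalization: Pixel values are scaled to the range 0-1. Note that ResNet50's ImageNet weights were trained with tf.keras.applications.resnet50.preprocess_input, which is applied to 0-255 pixel values, rather than with 0-1 scaling. Using that function for the teacher's inputs may make better use of the pretrained features.
    • Resizing: The teacher was built with input_shape=(128, 128, 3), so the images are upscaled to 128x128 with tf.image.resize(). ResNet50 itself does not require this size: with include_top=False it accepts any input of at least 32x32. Resizing the full training set at once produces roughly 9.8 GB of float32 data, so this step needs plenty of RAM.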
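The memory figure follows from simple arithmetic: 50,000 images x 128 x 128 x 3 values x 4 bytes. The sketch below computes it and then shows one way to avoid holding everything at once, by normalizing and upscaling one batch at a time. Two assumptions apply. First, it swaps the bilinear tf.image.resize for nearest-neighbour pixel repetition (np.repeat), which works here because 128 is exactly 4 x 32. Second, `resized_bytes` and `upscale_batches` are hypothetical helpers. In a TensorFlow pipeline, the more usual route would be tf.data.Dataset.map with tf.image.resize.

```python
import numpy as np

def resized_bytes(num_images, size=128, channels=3, bytes_per_value=4):
    """Memory needed to hold num_images float32 images of size x size x channels."""
    return num_images * size * size * channels * bytes_per_value

def upscale_batches(images, labels, batch_size=32, factor=4):
    """Yield normalized, upscaled (images, labels) batches one at a time.

    Assumption: nearest-neighbour upscaling by pixel repetition (32x32 -> 128x128
    for factor=4), a stand-in for tf.image.resize's bilinear interpolation.
    """
    for start in range(0, len(images), batch_size):
        batch = images[start:start + batch_size].astype(np.float32) / 255.0
        batch = np.repeat(np.repeat(batch, factor, axis=1), factor, axis=2)
        yield batch, labels[start:start + batch_size]

full = resized_bytes(50_000)
print(full)              # 9830400000 (bytes)
print(full / 1024**3)    # 9.1552734375 (GiB)
```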

Step 3: Train the Teacher Model

# Train the teacher model.
teacher_model.fit(
    X_train,
    y_train,
    validation_data=(X_val, y_val),
    epochs=10,
    batch_size=32,
    verbose=1,
)

teacher_output = teacher_model.predict(X_val)
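  • Train Teacher Model: No layers are frozen, so the whole ResNet50 backbone is fine-tuned together with the new head on CIFAR-10, for 10 epochs with a batch size of 32. After training, teacher_output stores the teacher's class probabilities for the 10,000 validation images (one 10-element row per image). The teacher's predictions are then used to guide the student model.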
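Each row of teacher_output is a probability distribution over the 10 classes. Its argmax is the predicted class, and its maximum is the teacher's confidence. The NumPy sketch below shows how such a matrix could be summarized. `summarize_teacher_output` is a hypothetical helper, and the toy probabilities are made up rather than real teacher results.

```python
import numpy as np

def summarize_teacher_output(teacher_probs, labels):
    """Accuracy and mean top-1 confidence of a (num_images, num_classes) probability matrix."""
    predicted = teacher_probs.argmax(axis=1)
    accuracy = float(np.mean(predicted == labels.flatten()))
    confidence = float(teacher_probs.max(axis=1).mean())
    return accuracy, confidence

# Toy stand-in for teacher_output: 3 images, 10 classes, each row sums to 1
toy_probs = np.full((3, 10), 0.02)
toy_probs[0, 3] = 0.82   # confident and correct
toy_probs[1, 5] = 0.82   # confident but wrong
toy_probs[2, 1] = 0.82   # confident and correct
toy_labels = np.array([[3], [0], [1]])  # same (N, 1) shape as CIFAR-10's y_val

# Accuracy is 2/3 and mean confidence is about 0.82 for this toy data
print(summarize_teacher_output(toy_probs, toy_labels))
```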

Step 4: Define the Student Model

# Define a small, simple CNN as the student model
student_model = Sequential()

# Convolutional feature extractor (Conv2D, not Dense, for the 3x3 kernels)
student_model.add(Conv2D(16, (3, 3), activation='relu', input_shape=(128, 128, 3)))
student_model.add(MaxPooling2D(3, 3))
student_model.add(Conv2D(32, (3, 3), activation='relu'))
student_model.add(MaxPooling2D(3, 3))
student_model.add(Conv2D(64, (3, 3), activation='relu'))
student_model.add(Flatten())
student_model.add(Dense(128, activation='relu'))
student_model.add(Dropout(0.5))
student_model.add(Dense(10, activation='softmax'))

student_model.summary()
  • Student Model Definition: The student is deliberately much smaller than the teacher. It has three 3x3 convolutional layers with 16, 32 and 64 filters, and the first two are each followed by 3x3 max pooling to shrink the feature maps. These are followed by Flatten, a 128-unit Dense layer, and Dropout(0.5) to reduce overfitting.
  • Dense(10, activation='softmax'): The output layer predicts probabilities for 10 classes.
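The student's size can be checked by hand. The pure-Python sketch below assumes the Conv2D layers use Keras's defaults (stride 1, padding='valid') and that MaxPooling2D(3, 3) means a 3x3 pool with stride 3. It tracks the spatial size through the layers and counts the parameters, which student_model.summary() should also report. The total is about 1.0 million, compared with roughly 24 million for the ResNet50-based teacher. Most of the student's parameters sit in the Dense(128) layer after Flatten.

```python
def conv_output_size(size, kernel):
    """Spatial size after a 'valid' (unpadded), stride-1 convolution."""
    return size - kernel + 1

def pool_output_size(size, pool):
    """Spatial size after max pooling with stride = pool and 'valid' padding."""
    return (size - pool) // pool + 1

def conv_params(kernel, in_channels, filters):
    return kernel * kernel * in_channels * filters + filters  # weights + biases

def dense_params(inputs, units):
    return inputs * units + units  # weights + biases

size = 128
size = pool_output_size(conv_output_size(size, 3), 3)  # 126 -> 42
size = pool_output_size(conv_output_size(size, 3), 3)  # 40 -> 13
size = conv_output_size(size, 3)                        # 11
flat = size * size * 64                                 # features after Flatten

total = (conv_params(3, 3, 16) + conv_params(3, 16, 32) + conv_params(3, 32, 64)
         + dense_params(flat, 128) + dense_params(128, 10))
print(flat, total)  # 7744 1016234
```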

Step 5: Define Distillation Loss

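def distillation_loss(org_prediction, student_prediction, teachers_output, temperature=3.0, alpha=0.5):
    # org_prediction: one-hot true labels; student_prediction / teachers_output:
    # softmax probabilities of shape (batch, 10) for the same images
    eps = K.epsilon()

    # Hard-label loss: standard categorical cross-entropy against the true labels
    student_loss = categorical_crossentropy(org_prediction, student_prediction)

    # Both models end in softmax, so log-probabilities act as logits (up to a constant);
    # dividing them by temperature > 1 and re-applying softmax softens the distributions
    teachers_soft = tf.nn.softmax(tf.math.log(teachers_output + eps) / temperature)
    student_soft = tf.nn.softmax(tf.math.log(student_prediction + eps) / temperature)

    # Kullback-Leibler divergence KL(teacher || student), one value per example
    kl_loss = tf.reduce_sum(
        teachers_soft * tf.math.log((teachers_soft + eps) / (student_soft + eps)), axis=-1)

    # Combined loss; the temperature**2 factor keeps the soft-target term on a scale
    # comparable to the hard-label term
    return alpha * student_loss + (1 - alpha) * temperature ** 2 * kl_loss
  • Distillation Loss Function: This function calculates the combined loss for training the student model. It expects one-hot true labels, plus the teacher's and the student's softmax probabilities for the same batch of images.
    • student_loss: The standard categorical cross-entropy between the true labels and the student's predictions.
    • Soft Targets: Both outputs are softened with a temperature T (3.0 by default). Because the models already end in softmax, the log-probabilities are divided by T and passed through softmax again. A temperature above 1 flattens the distribution, so the small probabilities the teacher assigns to the other classes carry more weight. A temperature below 1 would sharpen the distributions instead of softening them.
    • Distillation Loss: The Kullback-Leibler divergence between the teacher's and the student's softened distributions, computed per example.
    • Combined Loss: alpha * student_loss + (1 - alpha) * T^2 * KL. The T^2 factor compensates for the soft-target gradients shrinking roughly as 1/T^2, and alpha sets the balance between true labels and teacher.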
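The effect of the temperature is easy to see on a single made-up teacher distribution. The NumPy sketch below computes softened distributions as softmax(log(p) / T), treating log-probabilities as logits since the teacher ends in softmax. `soften` and `entropy` are hypothetical helpers. With T = 1 the distribution is unchanged. T > 1 spreads probability onto the runner-up classes (higher entropy), while T < 1 sharpens it toward the top class.

```python
import numpy as np

def soften(probs, temperature):
    """Temperature-softened distribution: softmax(log(probs) / T)."""
    logits = np.log(probs) / temperature
    exp = np.exp(logits - logits.max())   # subtract the max for numerical stability
    return exp / exp.sum()

def entropy(p):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return float(-(p * np.log(p)).sum())

# Hypothetical teacher probabilities for one image over 4 classes
teacher = np.array([0.90, 0.07, 0.02, 0.01])

for T in (0.3, 1.0, 3.0):
    soft = soften(teacher, T)
    print(T, np.round(soft, 3), round(entropy(soft), 3))
```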

Step 6: Compile and Train the Student Model

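# Keras passes only (y_true, y_pred) to a loss, so each target row packs the
# one-hot label (columns 0-9) next to the teacher's probabilities (columns 10-19)
teacher_train_output = teacher_model.predict(X_train, batch_size=32)
y_train_packed = tf.concat(
    [tf.one_hot(y_train.flatten().astype('int32'), 10), teacher_train_output], axis=1)
y_val_packed = tf.concat(
    [tf.one_hot(y_val.flatten().astype('int32'), 10), teacher_output], axis=1)

def packed_accuracy(y_true, y_pred):
    # Accuracy against the true labels only (first 10 columns of the packed target)
    return tf.keras.metrics.categorical_accuracy(y_true[:, :10], y_pred)

# Compile the student model
student_model.compile(
    optimizer='adam',
    loss=lambda y_true, y_pred: distillation_loss(y_true[:, :10], y_pred, y_true[:, 10:]),
    metrics=[packed_accuracy]
)

# Train the model
student_model.fit(
    X_train,
    y_train_packed,
    validation_data=(X_val, y_val_packed),
    epochs=10,
    batch_size=32,
    verbose=1
)
  • Compile Student Model: The student model is compiled with the custom distillation loss defined earlier.
    • Keras calls a loss with only (y_true, y_pred). The teacher's probabilities for each training image are therefore packed into the target next to that image's one-hot label, and the lambda splits them apart again before calling distillation_loss. This keeps every teacher prediction aligned with its own image, even after Keras shuffles the data.
    • Because the packed target has 20 columns, packed_accuracy compares the predictions with the label columns only. It does not rely on the built-in 'accuracy' metric to interpret the packed tensor.
  • Train Student Model: The student model is trained for 10 epochs with a batch size of 32. During training, it learns both from the true labels and from the teacher model's soft predictions.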
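A Keras loss receives only (y_true, y_pred). One way to get per-image teacher predictions into it is to pack them into the target array next to the one-hot labels and slice them apart inside the loss. The packing convention is easy to check outside Keras. The NumPy sketch below builds the 20-column target for two images and splits it again. `pack_targets` and `unpack_targets` are hypothetical helpers, and the uniform teacher output is made up.

```python
import numpy as np

NUM_CLASSES = 10

def pack_targets(labels, teacher_probs):
    """Place one-hot labels (columns 0-9) beside teacher probabilities (columns 10-19)."""
    one_hot = np.eye(NUM_CLASSES, dtype=np.float32)[labels.flatten()]
    return np.concatenate([one_hot, teacher_probs.astype(np.float32)], axis=1)

def unpack_targets(packed):
    """Split a packed target back into (one-hot labels, teacher probabilities)."""
    return packed[:, :NUM_CLASSES], packed[:, NUM_CLASSES:]

labels = np.array([[3], [7]])                    # (N, 1), like CIFAR-10's y_train
teacher_probs = np.full((2, NUM_CLASSES), 0.1)   # hypothetical uniform teacher output
packed = pack_targets(labels, teacher_probs)
print(packed.shape)  # (2, 20)

one_hot, teacher_back = unpack_targets(packed)
print(one_hot.argmax(axis=1))  # [3 7]
```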
