Reimplementation of the method from the paper - MaskTune: Mitigating Spurious Correlations by Forcing to Explore
This project is a reimplementation of MaskTune, a technique described in the paper MaskTune: Mitigating Spurious Correlations by Forcing to Explore. MaskTune is a single-epoch finetuning technique that addresses spurious correlations in over-parametrized deep learning models. Spurious correlations are coincidental associations between input and target variables that can lead to poor generalization performance [6]. By concealing the features the model explored first, MaskTune forces the model to explore other input variables, breaking the myopic, greedy feature-seeking character of standard training and encouraging exploration that leverages more of the input.
This project is implemented on a modified MNIST dataset. One base model was trained on this dataset, along with several fine-tuned models using different masking approaches. You can access the model checkpoints from the /checkpoints directory. To load those checkpoints in the notebook, change the directory variable to the desired one, which will modify the root directory location.
The pipeline for the project is as follows:
The appropriate dataset is created from MNIST to better illustrate the effectiveness of the technique. First, we split the MNIST digits into two groups (0-4 and 5-9), remapped to class 0 and class 1, respectively. We then inject a spurious feature (a blue square in the top-left corner) into 99% of the samples of the new class 0 and 1% of the samples of the new class 1. For testing, both a raw and a modified, biased test set are used (both remapped in the same way). In the biased test set, the spurious feature appears only on class 1 samples.
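A minimal sketch of this dataset construction. The helper names (`add_spurious_square`, `make_biased_split`) and the 3-pixel square size are assumptions; the source only specifies a blue square in the top-left corner and the 99%/1% injection rates:

```python
import torch

def add_spurious_square(img, size=3):
    """Paint a 'blue' square into the top-left corner of a [1, 28, 28]
    grayscale image. MNIST is grayscale, so the image is expanded to 3
    channels first; the square maxes out the blue channel."""
    img = img.expand(3, -1, -1).clone()  # [1, 28, 28] -> [3, 28, 28]
    img[:2, :size, :size] = 0.0          # clear red/green in the corner
    img[2, :size, :size] = 1.0           # set blue to maximum
    return img

def make_biased_split(base, p_class0=0.99, p_class1=0.01, seed=0):
    """base: iterable of (img, digit) pairs, e.g. a torchvision MNIST
    dataset with ToTensor(). Remaps digits 0-4 -> class 0 and 5-9 ->
    class 1, injecting the square with per-class probabilities."""
    g = torch.Generator().manual_seed(seed)
    out = []
    for img, digit in base:
        label = 0 if digit < 5 else 1
        p = p_class0 if label == 0 else p_class1
        if torch.rand(1, generator=g).item() < p:
            img = add_spurious_square(img)
        else:
            img = img.expand(3, -1, -1).clone()
        out.append((img, label))
    return out
```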
This project uses the same small Convolutional Neural Network architecture as the paper. One feature of the SmallCNN class is the get_grad_cam_target_layer function, which grabs the last convolutional layer and uses it for saliency-map generation.
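A sketch of what such a class might look like. The exact layer widths and kernel sizes here are assumptions, not the checkpointed architecture; they are chosen so the last conv layer's activation is 8×8, matching the saliency-map size mentioned later:

```python
import torch
import torch.nn as nn

class SmallCNN(nn.Module):
    """Illustrative small CNN for the 2-class biased-MNIST task."""
    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5)   # 28 -> 24
        self.pool = nn.MaxPool2d(2)                    # halves spatial size
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)  # 12 -> 8
        self.fc = nn.Linear(32 * 4 * 4, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))       # -> [16, 12, 12]
        x = self.pool(torch.relu(self.conv2(x)))       # -> [32, 4, 4]
        return self.fc(x.flatten(1))

    def get_grad_cam_target_layer(self):
        # Grad-CAM hooks attach to the last conv layer's activations.
        return self.conv2
```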
The hyperparameters are the same as suggested by the authors:
lr = 0.01
momentum = 0.9
weight_decay = 1e-4
batch_size = 128
epochs = 20
lr_decay_epochs = 25
lr_decay_factor = 0.5
number_of_classes = 2
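Wired into PyTorch, these hyperparameters correspond to roughly the following setup (the stand-in linear model is illustrative; any `nn.Module` works):

```python
import torch

# Stand-in model; in the project this would be the SmallCNN.
model = torch.nn.Linear(10, 2)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)
# lr_decay_epochs = 25, lr_decay_factor = 0.5:
# halve the learning rate every 25 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5)
```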
The masking function $\mu$ is a key component of the MaskTune method. It identifies and masks the most discriminative features in each sample, as found by the fully trained model; it is therefore applied offline. This encourages the model to explore additional features during finetuning.
$\mu$: masking function; here xGradCAM is used.
For each sample $(x_i, y_i)$, with $x_i \in X$ and $y_i \in Y$, the masking is done as follows:

$x_i^{\text{masked}} = x_i \odot T(\mu(x_i))$

where $T$ is a thresholding function with threshold factor $\tau$ (i.e., $T = \mathbb{1}\{\mu(x_i) \leq \tau\}$) and $\odot$ denotes element-wise multiplication.
$T(\mu(x_i))$, which in our case has shape [8, 8], is upsampled to match the input size [3, 28, 28].
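A minimal sketch of this masking step. The min-max normalization of the saliency map and the helper name `apply_mask` are assumptions; the thresholding keeps pixels with $\mu(x_i) \leq \tau$ as in the formula above:

```python
import torch
import torch.nn.functional as F

def apply_mask(x, saliency, tau=0.5):
    """Mask the most salient regions of input x.

    x:        [3, 28, 28] image tensor
    saliency: [8, 8] xGradCAM map (higher = more discriminative)
    tau:      threshold factor; normalized saliency above tau is masked
    """
    # Min-max normalize so tau is comparable across samples (assumption).
    s = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    keep = (s <= tau).float()                        # T = 1{mu(x) <= tau}, [8, 8]
    keep = F.interpolate(keep[None, None], size=x.shape[-2:],
                         mode="nearest")[0, 0]       # upsample to [28, 28]
    return x * keep                                  # element-wise product
```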
The steps are the following:
This project experiments with 3 different masking methods, all of which leverage the saliency maps.
First, we obtain the checkpoint for the ERM model, trained with the cross-entropy loss function and the stochastic gradient descent optimizer. Training runs for 50 epochs (due to limited resources). The learning rate decays after every specified number of epochs. The final learning rate value from ERM training is later used as the finetuning learning rate.
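The ERM stage can be sketched as below. The function name `train_erm` is illustrative; it returns the final learning rate, which the text above says is reused for finetuning:

```python
import torch
import torch.nn as nn

def train_erm(model, loader, epochs=50, lr=0.01):
    """ERM training sketch: cross-entropy + SGD with step decay,
    matching the hyperparameter list above."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=25, gamma=0.5)
    for _ in range(epochs):
        for x, y in loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        scheduler.step()
    # Final lr, reused later as the finetuning learning rate.
    return optimizer.param_groups[0]["lr"]
```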
The models with the different masking methods and parameters are finetuned in the Masking and Finetuning section, where the desired configuration is set via the method and param variables. Finetuning the models one by one lets us save and load many checkpoints without RAM bottlenecks and plot each method's effectiveness.
The fine-tuned model checkpoints are then saved under appropriate names, which are later used to plot their performance.
Each model is tested on raw and biased test datasets.
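Evaluation on the two test sets can be sketched with a simple accuracy helper (the function and loader names are illustrative):

```python
import torch

@torch.no_grad()
def accuracy(model, loader):
    """Fraction of correctly classified samples over a test loader."""
    model.eval()
    correct = total = 0
    for x, y in loader:
        pred = model(x).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / total
```

Each checkpoint would then be scored twice, e.g. `accuracy(model, raw_test_loader)` and `accuracy(model, biased_test_loader)`.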
In the Plotting and Visualization section, we can modify the base_model and finetuned_model variables to output the saliency maps and masks for the desired finetuned model checkpoint.
From this plot, it is visible that MaskTune is a viable method that can boost performance significantly. However, the parameters and masking techniques must be selected appropriately. In our case, threshold methods with moderately high parameter values and top_k methods with small to moderate parameter values performed best. Mean masking with param=0.9 performed well on the biased test set but poorly on the raw one. Overall, the best performer on both the biased and raw test sets was top_k_0.1.
Here are the saliency maps for the worst-performing model, mean masking with a 0.1 threshold:
While it masks the spurious features, the model fails to classify non-spurious samples.
And here are the maps for the best-performing one, top_k_0.1:
We can tell that this model generalizes better and performs well even when encountering samples with spurious features.
Normalization and Masking Techniques:
Dataset Handling and Masking Function Implementation:
Grad-CAM Implementation and Handling:
General PyTorch Implementation and Training Loop:
Visualization and Plotting:
Masking Techniques in Spurious Correlation Mitigation:
MNIST Dataset Details:
Loading and Plotting MNIST Dataset:
Writing a Training Loop in TensorFlow:
Grad-CAM for CNN Visualization:
Bar Plot in Matplotlib: