ml-jku / MIM-Refiner

A Contrastive Learning Boost from Intermediate Pre-Trained Representations

Stage 2 refining does not update weights of encoder. #8

Closed · haribaskarsony closed this 1 week ago

haribaskarsony commented 1 week ago

Hi, I tried to stage2-refine a ViT model. The loss curves of the ID heads looked fine, but when I evaluated the refined encoder on a different task it was no different from the base model. When I compared the base model and the refined model, there was no change in the weight values. Is this because `is_frozen` is set to true in the config file below? https://github.com/ml-jku/MIM-Refiner/blob/main/src/yamls/stage2/l16_d2v2.yaml

BenediktAlkin commented 1 week ago

Hi, yes, this is expected behavior. As the ID heads are randomly initialized at the start of training, the ID objective would also be random and would propagate a random signal back to the encoder. Therefore, we first train only the ID heads (with a frozen encoder) in stage2, and once the ID heads are sufficiently good (after a couple of epochs) we propagate gradients back to the encoder in stage3.

This is split into stages because stage2 needs significantly less compute, so you can run it on a couple of GPUs or sometimes even on a single GPU.
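To illustrate the stage2 setup described above, here is a minimal PyTorch sketch: the pre-trained encoder is frozen while only the randomly initialized ID head receives gradients. The module definitions and the dummy objective are placeholders for illustration, not the repo's actual classes or losses.

```python
import torch
import torch.nn as nn

# stand-ins for a pre-trained ViT encoder and a fresh ID head (not the repo's classes)
encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 16))
id_head = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 8))

# "is_frozen: true" corresponds to disabling gradients for the encoder
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

# only the head's parameters are handed to the optimizer
opt = torch.optim.SGD(id_head.parameters(), lr=0.1)

x = torch.randn(4, 32)
before = [p.clone() for p in encoder.parameters()]

with torch.no_grad():
    feats = encoder(x)
loss = id_head(feats).pow(2).mean()  # dummy objective in place of the contrastive ID loss
opt.zero_grad()
loss.backward()
opt.step()

# encoder weights are unchanged after the update step, exactly as observed in this issue
assert all(torch.equal(a, b) for a, b in zip(before, encoder.parameters()))
```

In stage3, the same loop would run with `requires_grad = True` on the encoder (and its parameters included in the optimizer), so the now-trained ID heads propagate a meaningful signal back.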

For more information, see Appendix D.5 in the MIM-Refiner paper or the MAE-CT paper (Section 3 or Figure 4 describe this).

You can simply reuse your stage2 run by changing the corresponding stage_id in the stage3 yaml. So if your stage2 run id is "abcdefg", change the stage_id in this yaml from "xpge1tv6" to "abcdefg".
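A hypothetical excerpt of what that edit looks like in the stage3 yaml (the surrounding key layout is an assumption; check the actual stage3 yamls in the repo for the exact structure):

```yaml
# stage3 yaml (hypothetical excerpt): point stage_id at your own stage2 run
initializer:
  stage_id: abcdefg  # was xpge1tv6 (the authors' stage2 run)
```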

haribaskarsony commented 1 week ago

Understood, thank you. I evaluated my base model and the stage3-refined model (at epoch 2); the performance of the refined model is sub-par (but close) to the base model. Could this be an initial performance dip, with the refined model performing better later on?

BenediktAlkin commented 1 week ago

By default you should have a k-NN monitor during training that evaluates the k-NN accuracy (using only 10% of the ImageNet training set). This metric should increase steadily at the start of training.
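For reference, a self-contained sketch of what such a k-NN accuracy probe computes (the 10%-subset sampling, feature extraction, and cosine-similarity choice are assumptions for illustration; this is not the repo's monitor implementation):

```python
import numpy as np

def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=10):
    # cosine similarity between L2-normalized feature vectors
    train_feats = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test_feats = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = test_feats @ train_feats.T                 # (n_test, n_train)
    topk = np.argsort(-sims, axis=1)[:, :k]           # indices of the k nearest neighbors
    # majority vote over the neighbors' labels
    preds = np.array([np.bincount(train_labels[idx]).argmax() for idx in topk])
    return float((preds == test_labels).mean())

# toy data: two well-separated clusters, so the probe should score near 1.0
rng = np.random.default_rng(0)
f0 = rng.normal(0, 0.1, (50, 8)) + np.eye(8)[0]
f1 = rng.normal(0, 0.1, (50, 8)) + np.eye(8)[1]
train = np.vstack([f0[:40], f1[:40]]); train_y = np.array([0] * 40 + [1] * 40)
test = np.vstack([f0[40:], f1[40:]]); test_y = np.array([0] * 10 + [1] * 10)
acc = knn_accuracy(train, train_y, test, test_y)
```

The probe needs no training of its own, which is why it is cheap enough to run periodically during refinement; a steadily rising curve indicates the encoder's features are improving.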

This is what it looks like for MAE-Refined-L/16: [k-NN accuracy curve]