rsomani95 opened this issue 1 year ago
Are you fine-tuning for classification tasks or continuing to train on image-text pairs? In any case, one other thing to try is linearly interpolating the weights before and after fine-tuning -- you may find this reduces catastrophic forgetting. I.e., if you have state dicts `sd1` and `sd2`, try loading `{k: (1 - alpha) * sd1[k] + alpha * sd2[k] for k in sd1.keys()}`, where `alpha` is some number between 0 and 1.
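For concreteness, a minimal sketch of that interpolation with PyTorch -- the checkpoint paths and the `alpha` value below are placeholders, and it assumes both checkpoints were saved as plain state dicts with identical keys:

```python
import torch

# Placeholder paths; both checkpoints must come from the same architecture
# so their state dicts share the same keys.
sd1 = torch.load("pretrained.pt", map_location="cpu")  # weights before fine-tuning
sd2 = torch.load("finetuned.pt", map_location="cpu")   # weights after fine-tuning

alpha = 0.4  # 0.0 keeps the pre-trained weights, 1.0 keeps the fine-tuned weights
interpolated = {k: (1 - alpha) * sd1[k] + alpha * sd2[k] for k in sd1.keys()}

# model.load_state_dict(interpolated)  # load into a model built with the same config
```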
@mitchellnw thanks for your response. I'm training on image-text pairs. Thanks for the idea re. interpolation, I'll definitely give that a shot and report back my findings.
The weight interpolation suggestion was super helpful.

In the graph below, `alpha=0.0` is the pre-trained model and `alpha=1.0` is the fully fine-tuned model. Turns out an alpha of 0.4 goes a long way. What's shown here are validation scores across 19 downstream datasets.

I'm yet to test on ImageNet; I'll be doing that next.
Great to hear! In case you're interested, there's some more background on that trick here: https://arxiv.org/abs/2109.01903
In my fine-tuning experiments, I've run into catastrophic forgetting and was wondering if using EMA would help mitigate this. I'm not sure whether it makes sense to apply it to the image encoder alone, or to both the text and image encoders.
If it makes sense, I'd love to try implementing this with some guidance.
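For what it's worth, a rough sketch of what a parameter-space EMA could look like during fine-tuning -- this is only an illustration assuming a standard PyTorch module, not anything specific to this repo's training loop:

```python
import copy
import torch

def update_ema(ema_model, model, decay=0.999):
    # EMA update: ema <- decay * ema + (1 - decay) * current weights.
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

# ema_model = copy.deepcopy(model)   # start the EMA copy from the current weights
# for batch in loader:               # inside the fine-tuning loop
#     ...                            # forward / backward / optimizer.step()
#     update_ema(ema_model, model)   # update after each optimizer step
# Evaluate ema_model (or just its image encoder) instead of model.
```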