ml-jku / MIM-Refiner

A Contrastive Learning Boost from Intermediate Pre-Trained Representations
MIT License

train batch_size #9

Closed: haribaskarsony closed this issue 6 days ago

haribaskarsony commented 1 week ago

Hi, what is the minimum batch size that needs to be maintained for stage 2 and stage 3 refining?

I have the resources for stage 2 refining with batch_size=1024, but I'm unable to do the same for stage 3.

I was wondering whether we have to keep the batch_size the same for both stage 2 and stage 3. If so, what is the minimum batch_size that wouldn't affect the ID heads' learning capability?

Dataset: 27k classes, with 20 images in each class.

BenediktAlkin commented 1 week ago

We always kept the batch size the same for stage 2 and stage 3, but that is not strictly necessary.

You can also go lower than 1024; for example, we used 512 for the ViT-H and ViT-2B experiments, and it worked quite well. However, note that batch size and temperature are related hyperparameters, because a larger batch provides more negative samples in the loss. For ViT-H and ViT-2B, higher temperatures work well (0.3 for ViT-H and 0.35 for ViT-2B).
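To make the batch size/temperature link concrete, here is a minimal sketch of a generic InfoNCE-style contrastive loss in NumPy. It is not the repository's NN objective; it only assumes the common setup where row i of one batch is matched with row i of another and every other row in the batch serves as a negative, so a batch of B samples gives each anchor B - 1 negatives, and the temperature scales every similarity before the softmax.

```python
import numpy as np


def info_nce_loss(queries, keys, temperature=0.2):
    """Generic InfoNCE loss (illustrative sketch, not MIM-Refiner's code).

    Row i of `queries` is the positive pair of row i of `keys`; all other
    rows of `keys` in the batch act as negatives for that query.
    """
    # L2-normalize so dot products are cosine similarities
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    k = keys / np.linalg.norm(keys, axis=1, keepdims=True)
    # (B, B) similarity matrix, divided by the temperature
    logits = q @ k.T / temperature
    # numerically stable log-softmax over each row
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # positives sit on the diagonal; off-diagonal entries are the negatives
    return -np.mean(np.diag(log_probs))


def num_negatives(batch_size):
    # every other sample in the batch is a negative for a given anchor
    return batch_size - 1


# Usage: random embeddings for the two batch sizes suggested in the thread
rng = np.random.default_rng(0)
for batch_size in (256, 512):
    x = rng.normal(size=(batch_size, 128))
    y = x + 0.1 * rng.normal(size=(batch_size, 128))
    print(batch_size, num_negatives(batch_size), info_nce_loss(x, y, 0.2))
```

With identical, orthogonal query/key rows the loss reduces to log(1 + (B - 1) * exp(-1/tau)), which shows both effects at once: each extra sample adds another negative term, and the temperature tau controls how strongly those terms are suppressed.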

As the best temperature can vary anyway depending on the dataset and number of classes, I'd suggest simply trying a smaller batch size (512 or 256) with temperature 0.2 to get a baseline result. If that model is not good enough, you can then try increasing the batch size and tweaking the temperature.

In my experience, temperature and batch size are not "make-or-break" parameters: as long as they are in a reasonable range (temperature around 0.2 and batch size around 512), the model should converge just fine. If you want optimal performance, however, you'd have to tune them.

haribaskarsony commented 6 days ago

Thanks for the detailed information. Where do I adjust the temperature for the NN objective? I couldn't find it in the config files.

BenediktAlkin commented 6 days ago

Each YAML file is "preprocessed" by either the stage2_processor or the stage3_processor, which keeps the config files cleaner.

Each head has "_temp02" in its name: "temp" stands for temperature and "02" stands for 0.2. So renaming it to "_temp025" would change the temperature from 0.2 to 0.25.
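The naming convention above can be sketched as a pair of small helpers. This is an assumption-laden illustration, not the actual logic of stage2_processor or stage3_processor: it assumes the digits after "_temp" are read as "0." followed by the remaining digits, which matches the two examples given ("02" is 0.2, "025" is 0.25). The helper names and the head name "head_temp02" are hypothetical.

```python
import re


def temperature_from_head_name(name):
    """Hypothetical helper: decode the '_tempXX' tag in a head name.

    Assumes the first digit is the integer part and the remaining digits
    are the fractional part, e.g. '_temp02' -> 0.2, '_temp025' -> 0.25.
    """
    match = re.search(r"_temp(\d+)", name)
    if match is None:
        raise ValueError(f"no '_temp' tag in head name: {name!r}")
    digits = match.group(1)
    return float(f"{digits[0]}.{digits[1:]}")


def with_temperature(name, temperature):
    """Hypothetical helper: rewrite the '_tempXX' tag for a new temperature.

    Encodes by dropping the decimal point, e.g. 0.25 -> '025', 0.35 -> '035'.
    """
    code = f"{temperature:g}".replace(".", "")
    return re.sub(r"_temp\d+", f"_temp{code}", name)


# Usage with a made-up head name
print(temperature_from_head_name("head_temp02"))  # 0.2
print(with_temperature("head_temp02", 0.25))  # head_temp025
```

Keeping the temperature in the head name means a sweep only needs renamed heads rather than an extra config key, but whether the repository's processors accept arbitrary digit strings this way should be checked against their source.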

haribaskarsony commented 6 days ago

Thank you so much.