ml-jku / MIM-Refiner

A Contrastive Learning Boost from Intermediate Pre-Trained Representations

Largest batch size for stage3 training #11

Closed yaoh11-gne closed 1 month ago

yaoh11-gne commented 1 month ago

Hi! I am trying to set up stage 3 MIM-Refiner training with images of size (3, 224, 224) and found that I have to reduce the batch size to 64 to avoid OOM errors. I am using an A100 with 80 GB of memory. Is this an expected batch size?

I am asking because I saw that the default value for the batch size in stage 3 is 1024, which is also the value provided in the MIM-Refiner paper. So I thought it would be better to double-check in case I set up any configs wrong.

Thanks!

BenediktAlkin commented 1 month ago

Hi!

64 seems quite low as a batch size. We use 1024 for ViT-L and 512 for the larger models. If I had to guess, something like 256 would probably also work, but I think going even lower could impair performance quite a bit.

If memory is the limit, you can freeze the first couple of blocks and use fewer ID heads, both of which only degrade performance slightly (e.g. 4 heads instead of 8, Table 1a in the paper; freezing 12 blocks, Table 8 in the paper). Without freezing or reducing heads you should only be able to fit a batch size of around 100 into 80GB; with both changes you can max out the available GPU memory and should be able to fit a batch size of around 256, which should be fine for training.

We also freeze the first 6 blocks for ViT-2B because of memory issues. Simply add a vit_block_freezer to your yaml (see here).
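For orientation, a minimal sketch of what that could look like is below; the exact parameter names of the freezer are an assumption on my part, so check the linked ViT-2B yaml for the real schema:

```yaml
model:
  freezers:
    # hypothetical parameters: the actual schema is defined by the
    # vit_block_freezer in this codebase (see the linked yaml)
    - kind: vit_block_freezer
      end_index: 6  # freeze the first 6 blocks
```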

If none of these suggestions suffice, you can reduce the number of local crops to something like 8 or 6 (it's 10 by default). This can be done like this:

```yaml
datasets:
  train:
    template: ${yaml:datasets/imagenet/train_byolaug_localviews}
    template.vars.version: imagenet1k
    template.vars.n_local_views: 6
```
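Each local crop adds another forward pass through the student, so activation memory grows roughly with the number of views; dropping from 10 to 6 local crops therefore frees a noticeable chunk of memory at a small performance cost.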

Some additional techniques to limit memory consumption (that are not implemented in this codebase) would be to apply light masking (e.g. discard 25% of the image patches in the 2 global views) or gradient checkpointing.

My first approach would be to freeze 6 blocks, use 4 heads and train with the largest batch size that fits (I'd guess ~150). My second try would be to freeze 12 blocks, use 4 heads and 6 local crops, and again max out the batch size (I'd guess ~250); a hypothetical sketch of that setup is below.
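To make the second option concrete, here is a sketch of such a yaml; everything outside the datasets block uses assumed key names, not necessarily the codebase's exact schema:

```yaml
# hypothetical sketch of the second setup (freeze 12 blocks,
# 4 ID heads, 6 local crops, max batch size)
datasets:
  train:
    template: ${yaml:datasets/imagenet/train_byolaug_localviews}
    template.vars.version: imagenet1k
    template.vars.n_local_views: 6
model:
  freezers:
    - kind: vit_block_freezer
      end_index: 12  # assumed parameter name; freeze blocks 0-11
  # reduce the ID heads from 8 to 4 by removing head entries from
  # the stage 3 yaml (exact structure not shown here)
trainer:
  effective_batch_size: 256  # assumed key name; roughly the 80GB maximum
```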

yaoh11-gne commented 1 month ago

Thanks a lot for your suggestions. They are very helpful!