mlfoundations / open_clip

An open source implementation of CLIP.

SigLIP memory issue with a large batch size #765

Closed airogachev closed 6 months ago

airogachev commented 10 months ago

The original paper claims to use very large batches. With the current implementation, if I increase the batch size to even 1024, training fails with an OOM on the second iteration. I'm using 4 GPUs with 44.5 GB of memory each. So it seems as if memory from the first batch is not being freed after the step completes? Any ideas?

File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jovyan/itm-v2/open_clip/src/open_clip/model.py", line 254, in forward
    text_features = self.encode_text(text, normalize=True) if text is not None else None
  File "/home/jovyan/itm-v2/open_clip/src/open_clip/model.py", line 241, in encode_text
    x = self.transformer(x, attn_mask=self.attn_mask)
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jovyan/itm-v2/open_clip/src/open_clip/transformer.py", line 321, in forward
    x = r(x, attn_mask=attn_mask)
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jovyan/itm-v2/open_clip/src/open_clip/transformer.py", line 243, in forward
    x = x + self.ls_2(self.mlp(self.ln_2(x)))
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/code-server/lib64/python3.8/site-packages/torch/nn/modules/activation.py", line 684, in forward
    return F.gelu(input, approximate=self.approximate)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 308.00 MiB (GPU 1; 44.49 GiB total capacity; 40.23 GiB already allocated; 233.69 MiB free; 42.97 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
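For reference, a minimal sketch of how one might log allocator stats between steps to check whether allocated memory actually grows across iterations (the loop shown in the comments is illustrative, not the actual training code):

```python
import torch

def log_cuda_memory(step: int) -> None:
    # allocated = memory held by live tensors; reserved = memory cached by the allocator
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    print(f"step {step}: allocated={allocated:.2f} GiB "
          f"reserved={reserved:.2f} GiB peak={peak:.2f} GiB")

# Hypothetical usage inside the training loop:
# for step, batch in enumerate(dataloader):
#     loss = train_step(batch)
#     log_cuda_memory(step)
#     torch.cuda.reset_peak_memory_stats()  # track per-step peaks
```

Note the error message itself suggests that if reserved memory is much larger than allocated memory, the problem may be fragmentation rather than a leak, in which case setting `PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128` (the value here is just an example) can help.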
rwightman commented 10 months ago

@airogachev not sure why you think this is an issue or bug rather than the model simply being too large for the batch size? SigLIP scales better than CLIP, but it isn't magic at the larger model sizes. You still only reach the 32k total batch sizes by using lots of GPUs for big models; it just needs fewer GPUs than an equivalent global batch size with CLIP, and results appear a bit better at a lower global batch size.

I think the 224/256-resolution B/16 models should be able to do a per-device batch size of 1024 in ~24-32GB of memory, based on the paper's claims w.r.t. TPU-v4. I don't believe they published details of the total TPU count or per-device batch size for the larger models.

Make sure gradient checkpointing is enabled, use amp with bfloat16, etc.
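As a sketch of what that setup looks like with the open_clip API (the model/pretrained tags here are just examples, pick whichever SigLIP config you're training):

```python
import torch
import open_clip

# Illustrative model config; substitute the SigLIP variant you actually train.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-16-SigLIP', pretrained='webli')
tokenizer = open_clip.get_tokenizer('ViT-B-16-SigLIP')
model = model.cuda().train()

# Trade compute for memory: recompute transformer activations during backward.
model.set_grad_checkpointing(True)

# Run the forward pass under bfloat16 autocast to roughly halve activation memory.
images = torch.randn(1024, 3, 224, 224, device='cuda')     # dummy image batch
texts = tokenizer(['a photo of a cat'] * 1024).cuda()      # dummy captions
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    out = model(images, texts)
```

If you're using the bundled training script, the equivalents should be the `--grad-checkpointing` flag and `--precision amp_bf16`.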