merveenoyan / smol-vision

Recipes for shrinking, optimizing, customizing cutting edge vision models. 💜
Apache License 2.0
859 stars, 80 forks

Out of memory error running on 2 A100s with 80 GB each #8

Open · joris-sense opened this issue 2 months ago

joris-sense commented 2 months ago

Hey, thanks for creating these notebooks! I am trying to run Idefics_FT, but unfortunately it isn't working: I run into an out-of-memory error when calling trainer.train(), even though the machine has 2 A100s with 80 GB each. Any idea what the problem could be? I made sure not to initialize the model twice, and loading the model alone takes only about 10 GB, as expected.

The trainer seems to get up to step 3/2144, and the cell running trainer.train() also produces the following warnings:

/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
/opt/conda/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
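Memory needed to load a model and memory needed to train it are different quantities, which is one reason a model that loads in about 10 GB can still run out of memory inside trainer.train(). The sketch below is a rough back-of-envelope estimator, not a diagnosis of this issue. It assumes bf16 weights and gradients (2 bytes each) and two fp32 AdamW states (8 bytes in total) per trainable parameter, and it deliberately ignores activations, which for vision-language models that process many image tokens can be large. The function name and the 8e9 parameter count are hypothetical illustration values, not Idefics3's exact size.

```python
def estimate_training_memory_gb(
    n_params: float,
    weight_bytes: float = 2.0,        # bytes per stored weight (bf16 = 2, 8-bit = 1)
    grad_bytes: float = 2.0,          # bytes per gradient of a trainable parameter
    optim_bytes: float = 8.0,         # AdamW keeps two fp32 states per trainable parameter
    trainable_fraction: float = 1.0,  # 1.0 = full fine-tuning; small for LoRA-style adapters
) -> float:
    """Rough lower bound on training memory in GB (1e9 bytes), excluding activations."""
    trainable = n_params * trainable_fraction
    total_bytes = n_params * weight_bytes + trainable * (grad_bytes + optim_bytes)
    return total_bytes / 1e9


if __name__ == "__main__":
    n = 8e9  # hypothetical round parameter count, for illustration only
    print(f"full fine-tune, bf16: ~{estimate_training_memory_gb(n):.0f} GB")
    # Assumption: 8-bit frozen base, ~1% trainable adapters with fp32 gradients
    print(
        "8-bit base + small adapters: "
        f"~{estimate_training_memory_gb(n, weight_bytes=1.0, grad_bytes=4.0, trainable_fraction=0.01):.1f} GB"
    )
```

Under these assumptions, full fine-tuning of an 8-billion-parameter model needs about 96 GB before any activations, more than a single 80 GB card holds, while a frozen 8-bit base with small adapters stays close to its loading footprint; in that second case an OOM has to come from activations or other buffers the estimate leaves out.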
merveenoyan commented 1 month ago

@joris-sense I have added some lines to both the script and the notebook to set up the CUDA devices properly; can you check? The warnings are fine: 8-bit is not a native CUDA precision, so inputs have to be cast back and forth, which adds a bit of overhead to training time. Do you still get OOM?

joris-sense commented 1 month ago

The cell

for param in model.model.vision_model.parameters():
    param.requires_grad = False 

needs to be moved below the cell where the model is defined. Then I tried running it on 2xA100 again (rented from vast), and I get an error about FlashAttention not being found, even though the command installing it ran and the kernel was restarted. I'll stop investigating now, because it does work for me on an H100 (using quanto instead of BitsAndBytes), and H100s are actually cheaper than A100s on vast...

ImportError                               Traceback (most recent call last)
Cell In[12], line 43
     41     print(model.get_nb_trainable_parameters())
     42 else:
---> 43     model = Idefics3ForConditionalGeneration.from_pretrained(
     44         model_id,
     45         torch_dtype=torch.bfloat16,
     46         _attn_implementation="flash_attention_2",
     47     ).to(DEVICE)
     49     # if you'd like to only fine-tune LLM
     50     for param in model.model.vision_model.parameters():

File /opt/conda/lib/python3.11/site-packages/transformers/modeling_utils.py:3827, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3824     init_contexts.append(init_empty_weights())
   3826 config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
-> 3827 config = cls._autoset_attn_implementation(
   3828     config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
   3829 )
   3831 with ContextManagers(init_contexts):
   3832     # Let's make sure we don't run the init function of buffer modules
   3833     model = cls(config, *model_args, **model_kwargs)

File /opt/conda/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:728, in Idefics3PreTrainedModel._autoset_attn_implementation(cls, config, use_flash_attention_2, torch_dtype, device_map, check_device_map, **kwargs)
    715 @classmethod
    716 def _autoset_attn_implementation(
    717     cls,
   (...)
    723     **kwargs,
    724 ):
    725     """
    726     Overrides the method in `PreTrainedModel` to update the vision config with the correct attention implementation
    727     """
--> 728     config = super()._autoset_attn_implementation(
    729         config=config,
    730         use_flash_attention_2=use_flash_attention_2,
    731         torch_dtype=torch_dtype,
    732         device_map=device_map,
    733         check_device_map=check_device_map,
    734         **kwargs,
    735     )
    736     config.vision_config._attn_implementation = config._attn_implementation
    737     return config

File /opt/conda/lib/python3.11/site-packages/transformers/modeling_utils.py:1557, in PreTrainedModel._autoset_attn_implementation(cls, config, use_flash_attention_2, torch_dtype, device_map, check_device_map)
   1554     config._attn_implementation = "flash_attention_2"
   1556 if config._attn_implementation == "flash_attention_2":
-> 1557     cls._check_and_enable_flash_attn_2(
   1558         config,
   1559         torch_dtype=torch_dtype,
   1560         device_map=device_map,
   1561         hard_check_only=False,
   1562         check_device_map=check_device_map,
   1563     )
   1564 elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
   1565     # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
   1566     config = cls._check_and_enable_sdpa(
   1567         config,
   1568         hard_check_only=False if requested_attn_implementation is None else True,
   1569     )

File /opt/conda/lib/python3.11/site-packages/transformers/modeling_utils.py:1668, in PreTrainedModel._check_and_enable_flash_attn_2(cls, config, torch_dtype, device_map, check_device_map, hard_check_only)
   1664         raise ImportError(
   1665             f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
   1666         )
   1667     else:
-> 1668         raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
   1669 elif torch.version.hip:
   1670     if flash_attention_version < version.parse("2.0.4"):

ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: Flash Attention 2 is not available. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.
merveenoyan commented 1 month ago

Yes, for that reason it is defined within the full fine-tuning branch and not under the LoRA branch.

I am a bit busy at the moment (this is a side project at work). I will add a requirements file for each script, but in the meantime you should install Flash Attention with pip install flash-attn.

joris-sense commented 1 month ago

That code block now appears twice in the notebook, the first time before "We will load VQAv2 dataset", and that first occurrence is where it breaks.

Yeah, I tried installing flash-attn in several ways (the Jupyter notebook contains a !pip install command for it as well, and that part worked for me before), but somehow it didn't work this time.

Anyway thanks for updating these so far!
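The freezing loop quoted earlier in this thread reads model.model.vision_model, so it can only run once from_pretrained has created model; a copy placed before that cell fails. A minimal sketch of the load, freeze, check order follows. The helper names are hypothetical (not part of the notebook), and the helpers are written against the parameters() / requires_grad / numel() interface that torch.nn.Module and torch.nn.Parameter expose, so they need no imports of their own.

```python
def freeze_parameters(module) -> None:
    """Disable gradients for every parameter of `module`, e.g. model.model.vision_model."""
    for param in module.parameters():
        param.requires_grad = False


def count_parameters(model) -> tuple:
    """Return (trainable, total) parameter counts for any object exposing parameters()."""
    trainable = total = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total


# Usage, only after the model exists (full fine-tuning branch):
#   model = Idefics3ForConditionalGeneration.from_pretrained(model_id, ...)
#   freeze_parameters(model.model.vision_model)  # fine-tune only the LLM part
#   print(count_parameters(model))               # trainable count should drop
```

Printing the counts right after freezing is a cheap way to confirm the loop actually ran on the loaded model rather than failing or being skipped.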

merveenoyan commented 1 month ago

@joris-sense hmm, interesting. Maybe there's an issue with the environment; can you try running:

import flash_attn
print(flash_attn.__version__)

If nothing comes up, can you check whether flash-attn is installed by running pip freeze?
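The same check can be folded into the notebook so that a broken install is reported up front rather than as an ImportError deep inside from_pretrained. A minimal sketch, assuming that falling back to PyTorch's SDPA attention is acceptable; the helper name is hypothetical. It actually imports the package instead of only looking it up, because an installed flash-attn wheel built against a different torch or CUDA version can still fail at import time. transformers runs further checks of its own (the traceback above shows it also requires flash_attn >= 2.1.0), so a successful import here is necessary but not sufficient.

```python
import importlib


def pick_attn_implementation(package: str = "flash_attn", fallback: str = "sdpa") -> str:
    """Return "flash_attention_2" if `package` imports cleanly, otherwise `fallback`."""
    try:
        module = importlib.import_module(package)
    except ImportError as err:  # not installed, or a binary built for another torch/CUDA
        print(f"{package} could not be imported ({err}); using {fallback!r} attention")
        return fallback
    version = getattr(module, "__version__", "unknown")
    print(f"{package} {version} imported successfully")
    return "flash_attention_2"


# Usage, mirroring the notebook's from_pretrained call:
#   attn = pick_attn_implementation()
#   model = Idefics3ForConditionalGeneration.from_pretrained(
#       model_id, torch_dtype=torch.bfloat16, _attn_implementation=attn
#   )
```

Printing the caught error is the useful part: it distinguishes "not installed" (ModuleNotFoundError) from "installed but broken" (for example an undefined-symbol ImportError from the compiled extension).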

joris-sense commented 1 month ago

Sorry, I have already destroyed the 2xA100 instance... Anyway, it (at least my current notebook, which is based on your previous version) works perfectly fine on an H100 now, so I guess I'll pass the torch to whoever else runs into this issue with a similar setup.

(Sent by email in reply to Merve Noyan's comment of Wed, 11 Sept 2024, 12:32.)

merveenoyan commented 1 month ago

@joris-sense thanks a lot for taking the time to open these issues 💗 I think this happens from time to time; the steps above are how I debug it.

joris-sense commented 1 month ago

On another machine I got that "flash attention 2 is not available" error again, and commenting out os.environ["CUDA_VISIBLE_DEVICES"] = "4" made it work for me (with USE_LORA = USE_QLORA = False). Commenting out the line that requests Flash Attention instead yielded a "no GPU found" error.
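One plausible reading of this report: CUDA_VISIBLE_DEVICES = "4" points at a fifth GPU, so on a machine with fewer GPUs no CUDA device is visible at all. transformers' Flash Attention 2 availability check appears to require a visible CUDA device, which would surface as "not available" rather than as a clearer "no GPU" message. A minimal stdlib-only guard that sets the variable only when every requested index exists is sketched below. The function names are hypothetical, GPU counting assumes nvidia-smi is on PATH (returning 0 otherwise), only integer indices are handled (not GPU UUIDs), and, like the original line, it must run before torch first touches CUDA, in practice before importing torch.

```python
import os
import shutil
import subprocess
from typing import Optional


def count_nvidia_gpus() -> int:
    """Count physical GPUs listed by `nvidia-smi -L`; 0 if the tool is missing or fails."""
    if shutil.which("nvidia-smi") is None:
        return 0
    try:
        result = subprocess.run(
            ["nvidia-smi", "-L"], capture_output=True, text=True, check=True
        )
    except (subprocess.CalledProcessError, OSError):
        return 0
    # Physical GPU lines look like "GPU 0: NVIDIA A100-SXM4-80GB (UUID: ...)"
    return sum(1 for line in result.stdout.splitlines() if line.startswith("GPU "))


def set_visible_devices(requested: str, n_gpus: int) -> Optional[str]:
    """Export CUDA_VISIBLE_DEVICES only if every requested index is below n_gpus.

    Returns the value that was set, or None if the request was rejected,
    in which case CUDA_VISIBLE_DEVICES is left untouched.
    """
    try:
        indices = [int(part) for part in requested.split(",") if part.strip()]
    except ValueError:  # non-integer entries such as UUIDs are not handled here
        indices = []
    if not indices or any(i < 0 or i >= n_gpus for i in indices):
        print(f"GPU(s) {requested!r} not available ({n_gpus} found); not setting CUDA_VISIBLE_DEVICES")
        return None
    value = ",".join(str(i) for i in indices)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


# Usage at the very top of the notebook, before `import torch`:
#   set_visible_devices("4", count_nvidia_gpus())    # rejected on a 2-GPU machine
#   set_visible_devices("0,1", count_nvidia_gpus())  # exposes both A100s
```

Failing loudly at the first cell keeps a hardcoded device index from turning into a misleading Flash Attention error several cells later.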