huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Gradient checkpointing warning #32576

Closed BigDataMLexplorer closed 1 month ago

BigDataMLexplorer commented 2 months ago

System Info

Who can help?

@ArthurZucker @muellerzr @sunma

Information

Tasks

Reproduction

Hi, I need help with the gradient checkpointing settings for fine-tuning an LLM. I want to use it to reduce GPU memory usage. The System Info section lists my system details and library versions.

I am doing a text classification task using the AutoModelForSequenceClassification class with the Llama 3 8B model. I load the model, prepare it for k-bit training, apply LoRA via LoraConfig and get_peft_model, and set gradient_checkpointing=True in the Hugging Face Trainer, roughly as sketched below.
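
Illustrative sketch of the setup (the model id, LoRA hyperparameters, and dataset here are placeholders, not my exact values):

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

# Load the base model quantized so the k-bit preparation step applies.
model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    num_labels=2,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = prepare_model_for_kbit_training(model)

# Wrap the model with LoRA adapters for sequence classification.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type="SEQ_CLS",
)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # the setting that triggers the warnings below
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # placeholder: the tokenized classification dataset
)
trainer.train()
```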

Without gradient_checkpointing=True the training takes 9:40 hours and reaches about 84% accuracy. With gradient_checkpointing=True it takes about 4:47 hours but reaches only about 70% accuracy. When I set gradient_checkpointing=True in the Trainer, I also get these warnings:

env/lib/python3.9/site-packages/torch/utils/checkpoint.py:464:
UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed
explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed.
use_reentrant=False is recommended, but if you need to preserve the current default
behavior, you can pass use_reentrant=True. Refer to docs for more details on the
differences between the two variants.
  warnings.warn(

env/lib/python3.9/site-packages/torch/utils/checkpoint.py:91:
UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn(

Thanks for any help.

Expected behavior


amyeroberts commented 2 months ago

Hi @BigDataMLexplorer, thanks for reporting. Could you try the solution suggested here: https://github.com/huggingface/transformers/issues/26969#issuecomment-1807831645
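
For reference (a sketch of one common workaround, not a quote of the linked comment): the "None of the inputs have requires_grad=True" warning often shows up when PEFT adapters are combined with gradient checkpointing on a frozen quantized base model, and making the input embeddings require gradients before wrapping the model usually resolves it:

```python
from peft import get_peft_model, prepare_model_for_kbit_training

# `model` and `lora_config` as in the setup sketch above.
model = prepare_model_for_kbit_training(model)
model.enable_input_require_grads()           # embedding outputs require grad -> checkpointed activations get gradients
model = get_peft_model(model, lora_config)   # then wrap with the LoRA config as before
```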

ArthurZucker commented 2 months ago

Hey, this is also a bit weird, as we have this code: https://github.com/huggingface/transformers/blob/e683c378fff90cb6c986e1f80684bc3e5ed3cda5/src/transformers/modeling_utils.py#L2362-L2365

which should always default to re-entrant checkpointing, but still lets you set it to False:


model.gradient_checkpointing_enable({"use_reentrant": False})
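
For completeness, the same setting can also be routed through the Trainer, assuming a transformers version recent enough that TrainingArguments accepts gradient_checkpointing_kwargs:

```python
from transformers import TrainingArguments

# Sketch: let the Trainer enable non-reentrant checkpointing itself instead of
# calling gradient_checkpointing_enable manually.
training_args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```
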
github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.