huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

add warning when using gradient_checkpointing with FSDP full shard #31578

Open yundai424 opened 4 days ago

yundai424 commented 4 days ago

What does this PR do?

Adds a warning when the `gradient_checkpointing` training argument is combined with FSDP full sharding, encouraging users to enable `activation_checkpointing` in the FSDP config (`fsdp_config`) instead.

Fixes #30404
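The following is a minimal sketch of the kind of check described above. It is not the code added by this PR. The helper `warn_if_gc_with_full_shard` is hypothetical, and the warning text is illustrative. The keyword names in `recommended_kwargs` (`fsdp`, `fsdp_config`, `gradient_checkpointing`) are `TrainingArguments` parameters, and `activation_checkpointing` is an `fsdp_config` key.

```python
import warnings


def warn_if_gc_with_full_shard(fsdp, gradient_checkpointing):
    """Hypothetical helper: warn (and return True) when the
    `gradient_checkpointing` training argument is combined with FSDP full shard."""
    # `fsdp` may be a space-separated string such as "full_shard auto_wrap"
    # or a list of option strings.
    options = fsdp.split() if isinstance(fsdp, str) else list(fsdp)
    if gradient_checkpointing and "full_shard" in options:
        warnings.warn(
            "When using FSDP full shard, prefer `activation_checkpointing` in "
            "`fsdp_config` over the `gradient_checkpointing` training argument.",
            UserWarning,
        )
        return True
    return False


# The setup the warning nudges users toward: FSDP handles activation
# checkpointing, and the `gradient_checkpointing` training argument stays off.
recommended_kwargs = dict(
    fsdp="full_shard auto_wrap",
    fsdp_config={"activation_checkpointing": True},
    gradient_checkpointing=False,
)
```

With these keyword arguments the sketch stays silent. Passing `gradient_checkpointing=True` together with a `fsdp` value that contains `full_shard` triggers the `UserWarning`.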


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@SunMarc @muellerzr