huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Type of Accelerator.distributed_type() might be wrong #3205

Closed ffrancesco94 closed 3 weeks ago

ffrancesco94 commented 3 weeks ago

Hi, I am trying to run one of the Slurm examples, specifically multi-node GPU training. As a small test, I requested two GPUs on two different nodes. If, in the complete_nlp_example.py script, I print the number of processes with accelerator.num_processes and the distributed type with accelerator.distributed_type, I get 2 processes (correct), but the type is DistributedType.MULTI_CPU. Is this just the multi-node setup "shadowing" the fact that each node actually has GPUs, or does it mean that the model is actually running on CPUs? Thank you!
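The check described in the question can be sketched as a small helper that each rank calls after creating its Accelerator. This is a minimal sketch, not part of Accelerate: diagnose and report are hypothetical names, and the only library attributes used are accelerator.distributed_type, accelerator.num_processes and torch.cuda.is_available(). The imports are deferred so the verdict logic can be exercised without a GPU node.

```python
def diagnose(distributed_type: str, num_processes: int, cuda_available: bool) -> str:
    """Hypothetical helper: turn the values printed by the script into a short verdict."""
    if distributed_type == "MULTI_CPU" and cuda_available:
        # CUDA is visible on this node, yet Accelerate picked the CPU backend.
        return f"{num_processes} processes running on CPU although CUDA is available"
    if distributed_type == "MULTI_CPU":
        return f"{num_processes} processes running on CPU (no CUDA device visible)"
    return f"{num_processes} processes, backend {distributed_type}"


def report() -> str:
    """Hypothetical helper: call on each rank inside the launched training script."""
    # Imported lazily so diagnose() above has no heavy dependencies.
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    return diagnose(
        accelerator.distributed_type.name,  # DistributedType is an Enum, e.g. "MULTI_GPU"
        accelerator.num_processes,
        torch.cuda.is_available(),
    )
```

If torch.cuda.is_available() is False on a rank, the Slurm allocation itself probably did not expose a GPU to that process, which is worth ruling out before suspecting Accelerate.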

muellerzr commented 3 weeks ago

It would mean it's actually running on CPU. You can verify this by checking the device of the model or of the data coming in.
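Checking "the device of the model or the data coming in" could look like the sketch below. It assumes, as in the NLP example, that each batch is dict-like (a mapping of names to tensors); devices_in_use is a hypothetical helper, and it relies only on the .device.type attribute that PyTorch parameters and tensors expose.

```python
def devices_in_use(model, batch) -> set[str]:
    """Hypothetical helper: device types ("cpu", "cuda", ...) of parameters and batch tensors."""
    found = {param.device.type for param in model.parameters()}
    # Assumes a dict-like batch; non-tensor values (no .device attribute) are skipped.
    found.update(value.device.type for value in batch.values() if hasattr(value, "device"))
    return found
```

Inside the training loop, after accelerator.prepare(...), something like print(accelerator.process_index, devices_in_use(model, batch)) should report {"cuda"} on every rank of a GPU run; seeing {"cpu"} confirms the CPU fallback.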

ffrancesco94 commented 3 weeks ago

Thank you! So, to be clear: whether the GPUs all sit on one node or are spread across several nodes, the type should always be MULTI_GPU?

PS: would checking the device still be done through the Accelerator object, or some other way? Thank you very much!

(Sent by email in reply to Zach Mueller's comment of Wed, 30 Oct 2024, 18:34.)

muellerzr commented 3 weeks ago

Correct.

For a quick one-liner, you can also do:

from accelerate import PartialState; print(PartialState().device)
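Expanding that one-liner into a per-rank report can make a multi-node run easier to read, since every process prints its own line. This is a sketch under stated assumptions: format_rank_line and print_rank_info are hypothetical names, and only PartialState attributes that exist in Accelerate (process_index, num_processes, device, distributed_type) are used.

```python
def format_rank_line(process_index: int, num_processes: int, device: str, distributed_type: str) -> str:
    """Hypothetical helper: one readable line per rank."""
    return f"[rank {process_index}/{num_processes}] device={device} type={distributed_type}"


def print_rank_info() -> None:
    """Hypothetical helper: call once from the launched script on every rank."""
    from accelerate import PartialState  # deferred so the formatter needs no dependencies

    state = PartialState()
    print(format_rank_line(state.process_index, state.num_processes,
                           str(state.device), state.distributed_type.name))
```

Run under the usual launcher (for example accelerate launch or srun); with one GPU per node you would expect each rank to print a cuda device and type=MULTI_GPU, whereas device=cpu and type=MULTI_CPU reproduce the situation reported in this issue.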

ffrancesco94 commented 3 weeks ago

OK, it seems I'm hitting some kind of bug; I'll open another thread for that.