huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Feature Request: Pipeline multiple batches together for Llama3 70B distributed inference #2896

Closed ishan-gaur closed 1 month ago

ishan-gaur commented 3 months ago

Hi all,

I am using Llama3 70B for inference on multiple GPUs via the text generation pipeline. I noticed that only one GPU seems to be active at a time, with one batch being executed at a time. Is there a way to extend the pipeline or use accelerate so that batches are pipelined together, with each GPU/shard of the model executing a different batch at any given moment?

Happy to help implement this if needed, but would appreciate your guidance on a good approach to do so!

SunMarc commented 3 months ago

Hi @ishan-gaur, we've added support for PiPPy in accelerate. You can give it a try. More recently, PiPPy migrated into torch (`torch.distributed.pipelining`), and we will make sure to support it when it is released (torch 2.5).
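
In case it helps, here is a rough sketch of what that looks like with the accelerate integration today. The checkpoint, prompts, and script name are placeholders, and the exact `prepare_pippy` arguments may differ slightly between accelerate versions:

```python
# launched with e.g.: accelerate launch --num_processes 8 pippy_inference.py
import torch
from accelerate import PartialState, prepare_pippy
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Meta-Llama-3-70B"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
model.eval()

# Example inputs are needed so the model can be traced and split into stages.
# At run time the batch is divided into micro-batches (the num_chunks argument)
# that are streamed through the GPU stages concurrently.
prompts = ["The capital of France is", "The theory of relativity was proposed by"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
model = prepare_pippy(model, split_points="auto", example_args=(inputs["input_ids"],))

# Every process calls forward; the inputs are consumed by the first stage.
with torch.no_grad():
    output = model(inputs["input_ids"].to("cuda:0"))

# By default only the last stage holds the output
# (pass gather_output=True to prepare_pippy to send it to every rank).
if PartialState().is_last_process:
    print(output)
```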

ishan-gaur commented 3 months ago

Thanks @SunMarc will give it a try!

ishan-gaur commented 3 months ago

Sorry, just one quick clarifying question, @SunMarc: the PiPPy section here shows the "bubble" during the backward pass. If I am only doing inference in eval mode, does it automatically pipeline the next batch, or does it wait until the four micro-batches that were initially loaded together are finished?

SunMarc commented 3 months ago

No, it won't wait; it will automatically pipeline the next batch!

ishan-gaur commented 2 months ago

@SunMarc should this be able to work for generation pipelines too?

I'm not sure if it's directly supported, but currently I'm trying to run the forward pass of the model underlying AutoModelForCausalLM with the same input kwargs that HF's generation code would pass downstream of my model.generate call. However, I'm not sure whether these kwargs change across the subsequent calls, which seems to be unsupported.

Also please let me know if there is a more appropriate forum to post such questions. Happy to move this accordingly.

ishan-gaur commented 2 months ago

Is something like this the best way, and then we just write our sampling logic manually?

https://github.com/pytorch/PiPPy/blob/main/examples/llama/pippy_llama.py

SunMarc commented 2 months ago

Hi @ishan-gaur, thanks for asking. Yes, we can definitely discuss this here, as other users might want to know. For now, it is better to write your own sampling logic! PiPPy is not yet compatible with generate, and I'm not sure when it will be.
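
Roughly along the lines of the linked pippy_llama.py example, yes. A minimal sketch of the sampling side, assuming the `model`, `inputs`, and `tokenizer` from a `prepare_pippy` setup like the one above (only the last rank sees the output unless `gather_output=True` is passed, and the exact output structure can vary by version):

```python
import torch
from accelerate import PartialState

# One pipelined forward pass over the prompt batch, then a greedy pick of the
# next token from the last position, as in the linked pippy_llama.py example.
with torch.no_grad():
    outputs = model(inputs["input_ids"].to("cuda:0"))

if PartialState().is_last_process:
    # Depending on the version, the traced model may return a tuple or a ModelOutput.
    logits = outputs[0] if isinstance(outputs, (tuple, list)) else getattr(outputs, "logits", outputs)
    next_tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy "sampling"
    print(tokenizer.batch_decode(next_tokens))

# Generating more than one token means looping: append the chosen token, feed
# the sequence back to the first stage, and keep the shapes consistent with
# what the model was traced with (e.g. pad to a fixed length), since there is
# no KV cache across the pipeline here.
```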

SunMarc commented 2 months ago

Also, I encourage you to add working scripts to accelerate or even to the PiPPy repository, so that others can benefit from what you've learned!

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.