predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Efficient implementation of all_reduce and all_gather for collect_lora_a #360

Open hayleyhu opened 6 months ago

hayleyhu commented 6 months ago

Feature request

We noticed that collect_lora_a() calls all_gather and all_reduce every time it is invoked.
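For context, here is a rough sketch of why a tensor-parallel LoRA layer needs a collective between X @ A and (X @ A) @ B. It assumes A is split across ranks either by column (outputs combined with all_gather) or by row, with the input features also sharded (partial sums combined with all_reduce). The helper names are hypothetical, and ranks are simulated with plain Python lists rather than torch.distributed; this is an illustration, not LoRAX's implementation.

```python
from typing import List

Matrix = List[List[int]]


def matmul(x: Matrix, a: Matrix) -> Matrix:
    """Plain (n x k) @ (k x m) matrix product on nested lists."""
    return [[sum(x[i][t] * a[t][j] for t in range(len(a))) for j in range(len(a[0]))]
            for i in range(len(x))]


def split_cols(m: Matrix, parts: int) -> List[Matrix]:
    """Split a matrix into `parts` equal column blocks (one per simulated rank)."""
    step = len(m[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in m] for p in range(parts)]


def split_rows(m: Matrix, parts: int) -> List[Matrix]:
    """Split a matrix into `parts` equal row blocks (one per simulated rank)."""
    step = len(m) // parts
    return [m[p * step:(p + 1) * step] for p in range(parts)]


def all_gather_cat(shards: List[Matrix]) -> Matrix:
    """Simulated all_gather followed by a concat along dim 1."""
    return [sum((s[i] for s in shards), []) for i in range(len(shards[0]))]


def all_reduce_sum(partials: List[Matrix]) -> Matrix:
    """Simulated all_reduce(SUM): element-wise sum of every rank's partial result."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]


WORLD = 2                              # hypothetical tensor-parallel world size
X = [[1, 2, 3, 4], [5, 6, 7, 8]]       # 2 tokens x 4 input features
A = [[1, 0], [0, 1], [1, 1], [2, -1]]  # LoRA A: 4 input features x rank 2
full = matmul(X, A)                    # what a single GPU would compute

# Column-sharded A: each rank owns some rank-columns, so outputs are concatenated.
col_out = all_gather_cat([matmul(X, a_i) for a_i in split_cols(A, WORLD)])

# Row-sharded A (input features also sharded): each rank holds a partial sum.
row_out = all_reduce_sum(
    [matmul(x_i, a_i) for x_i, a_i in zip(split_cols(X, WORLD), split_rows(A, WORLD))]
)

print(full, col_out, row_out)  # [[12, 1], [28, 5]] [[12, 1], [28, 5]] [[12, 1], [28, 5]]
```

In both cases each rank needs the complete (n x r) result of X @ A before multiplying by B (assuming B is not sharded along r), which is why some collective sits at this point whenever A is sharded.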

Could you provide a more efficient implementation of this soon? If not, could you share some ideas so we can implement it ourselves?

Thanks.

Motivation

We observe significant time spent in all_reduce and all_gather during LoRA adapter inference: with an adapter, 60% of first-token communication time is spent in all_gather operations, compared with only 0.01% without adapters (and communication accounts for 55% of GPU operation time).
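Assuming both percentages come from the same with-adapter first-token profile (the report does not state this explicitly), all_gather alone would account for roughly a third of first-token GPU operation time:

```math
\underbrace{0.60}_{\text{all\_gather share of communication}} \times \underbrace{0.55}_{\text{communication share of GPU time}} = 0.33 \approx \tfrac{1}{3}
```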

As a result, first-token latency suffers, which makes a big difference to user experience.

Your contribution

We can do the implementation if the advice is clear.

tgaddair commented 6 months ago

Hey @hayleyhu,

Thanks for submitting this issue. To clarify: are you running with the SGMV kernel disabled? (The latest version of LoRAX should log during initialization whether it is enabled.)

When using SGMV, we only run the all_gather and all_reduce once per layer, which is optimal. It is true that when SGMV is disabled we do it once per adapter, which is much slower, but that is primarily a fallback code path that most users should not encounter.
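The difference between one collective per adapter and one per layer can be sketched as follows. This assumes a row-sharded LoRA A (so the collective is an all_reduce), that each adapter covers a contiguous segment of tokens, and that all adapters share the same rank r. `FakeProcessGroup`, `reduce_per_adapter` and `reduce_batched` are hypothetical names, and ranks are simulated in plain Python; this is only loosely analogous to the SGMV path, not LoRAX's actual code.

```python
from typing import Dict, List

Matrix = List[List[int]]  # rows = tokens, columns = LoRA rank r


class FakeProcessGroup:
    """Hypothetical stand-in for a tensor-parallel group: sums per-rank partials, counts calls."""

    def __init__(self) -> None:
        self.calls = 0

    def all_reduce(self, partials: List[Matrix]) -> Matrix:
        self.calls += 1
        rows, cols = len(partials[0]), len(partials[0][0])
        return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]


def reduce_per_adapter(pg: FakeProcessGroup, segs: Dict[str, List[Matrix]]) -> Dict[str, Matrix]:
    """One collective per adapter segment."""
    return {name: pg.all_reduce(per_rank) for name, per_rank in segs.items()}


def reduce_batched(pg: FakeProcessGroup, segs: Dict[str, List[Matrix]]) -> Dict[str, Matrix]:
    """One collective per layer: stack every segment by token, reduce once, split back."""
    names = list(segs)
    world = len(segs[names[0]])
    # On each rank, stack all adapters' token segments into one (n_tokens x r) buffer.
    packed = [[row for name in names for row in segs[name][rank]] for rank in range(world)]
    reduced = pg.all_reduce(packed)
    out, start = {}, 0
    for name in names:
        n = len(segs[name][0])  # number of tokens routed to this adapter
        out[name] = reduced[start:start + n]
        start += n
    return out


# segs[adapter][rank] = that rank's partial X @ A for the adapter's tokens (2 ranks, r = 2).
segs = {
    "adapter_a": [[[1, 2]], [[10, 20]]],
    "adapter_b": [[[3, 4], [5, 6]], [[30, 40], [50, 60]]],
    "adapter_c": [[[7, 8]], [[70, 80]]],
}
slow_pg, fast_pg = FakeProcessGroup(), FakeProcessGroup()
slow = reduce_per_adapter(slow_pg, segs)
fast = reduce_batched(fast_pg, segs)
print(slow == fast, slow_pg.calls, fast_pg.calls)  # True 3 1
```

Batching leaves the total payload unchanged but replaces one collective per adapter with one per layer. It does not remove the per-layer collective itself, so communication can still show up in a profile when SGMV is enabled.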

hayleyhu commented 6 months ago

SGMV is enabled, but the all_gather and all_reduce time is still very high.