Open hayleyhu opened 6 months ago
Hey @hayleyhu,
Thanks for submitting this issue. To clarify: are you running with the SGMV kernel disabled? (The latest version of LoRAX should log whether it is enabled during initialization.)
When using SGMV, we only run the all_gather and all_reduce once per layer, which is optimal. It is true that when SGMV is disabled we do it once per adapter, which is much slower, but this is primarily a fallback code path that most users shouldn't be encountering.
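To illustrate the difference described above, here is a toy sketch that only counts collective calls under the two strategies. It is not LoRAX code: the function names are hypothetical, and it assumes the fallback path issues one all_gather and one all_reduce per distinct adapter in the batch for every layer, while the SGMV path issues one of each per layer regardless of how many adapters are in the batch.

```python
from collections import defaultdict


def count_collectives_per_adapter(batch_adapter_ids, num_layers):
    """Toy model of the fallback path (assumption): one all_gather and one
    all_reduce per distinct adapter in the batch, for every layer."""
    calls = defaultdict(int)
    unique_adapters = set(batch_adapter_ids)
    for _ in range(num_layers):
        for _ in unique_adapters:
            calls["all_gather"] += 1
            calls["all_reduce"] += 1
    return dict(calls)


def count_collectives_per_layer(batch_adapter_ids, num_layers):
    """Toy model of the SGMV path (assumption): one all_gather and one
    all_reduce per layer, independent of the number of adapters."""
    calls = defaultdict(int)
    for _ in range(num_layers):
        calls["all_gather"] += 1
        calls["all_reduce"] += 1
    return dict(calls)


if __name__ == "__main__":
    # Four requests using three distinct (hypothetical) adapters, 32 layers.
    batch = ["adapter_a", "adapter_b", "adapter_a", "adapter_c"]
    print("per adapter:", count_collectives_per_adapter(batch, num_layers=32))
    print("per layer:  ", count_collectives_per_layer(batch, num_layers=32))
```

Under these assumptions the number of collectives in the fallback path grows with the number of distinct adapters in the batch, whereas the per-layer path stays constant.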
SGMV is enabled, but the time spent in all_gather and all_reduce is still very high.
Feature request
We noticed that collect_lora_a() calls all_gather and all_reduce every time it runs.
Do you think you could provide a more efficient implementation of this soon? If not, could you suggest how we might implement it ourselves?
Thanks.
Motivation
We observe significant time spent on all_reduce and all_gather during LoRA adapter inference. With an adapter, 60% of first-token communication time is spent on all_gather operations, compared with only 0.01% without adapters (and communication accounts for 55% of GPU operation time).
As a result, first-token latency increases noticeably, which has a big impact on user experience.
Your contribution
We can do the implementation if the advice is clear.