microsoft / vidur

A large-scale simulation framework for LLM inference
MIT License

Clarifications on All_Gather Handling in Profiling Communication Operators #46

Open zhhangBian opened 4 days ago

zhhangBian commented 4 days ago

Hello Vidur,

Thank you for sharing your work. While reading the code and documentation, I encountered some questions related to the Profiling Communication Operators mentioned in the paper.

In the paper, it is noted that there are three collective operations: all_reduce, all_gather, and send_recv. However, in the simulated device data located at data/compute, it seems that simulation parameters are provided only for all_reduce and send_recv. There are no simulation parameters for the all_gather operation.

After reviewing the relevant code in vidur/profiling, it appears that all_gather is treated as device-independent, and thus its parameters are not explicitly provided. However, isn't all_gather typically device-dependent? If so, could you clarify why it is treated as device-independent in this case?

Additionally, in vidur/profiling/collectives/main.py, the --collective argument only supports choices=["all_reduce", "send_recv"]. Could you explain the rationale behind excluding all_gather as an option here?

These are the points that confused me while going through the code. I would greatly appreciate it if you could clarify, or correct me if I have misunderstood any part of your work.

Thank you in advance for your time and insights!

AgrawalAmey commented 4 days ago

Hi @zhhangBian, we originally included all_gather to represent some parallel strategies that we were experimenting with. However, as of today we only use the all_reduce and send_recv operations -- which are sufficient to represent tensor and pipeline parallelism.
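
To make that concrete, here is a rough torch.distributed-style sketch (illustrative only, not code from Vidur): in the Megatron-style column-then-row partitioned MLP, a single all_reduce at the end of the block is enough, and pipeline stages only exchange activations with point-to-point send/recv.

```python
# Illustrative sketch only (plain torch.distributed, not Vidur's profiling code).
# Assumes a process group has already been initialized, e.g. dist.init_process_group("nccl").
import torch
import torch.distributed as dist


def tensor_parallel_mlp(x, w1_col_shard, w2_row_shard):
    """Megatron-style MLP: w1 is column-partitioned, w2 is row-partitioned.

    The column-parallel matmul leaves the hidden activation sharded across
    ranks; the row-parallel matmul then produces a partial sum of the output,
    so a single all_reduce recovers the full result -- no all_gather needed.
    """
    h_shard = torch.relu(x @ w1_col_shard)        # [batch, hidden / tp_size]
    y_partial = h_shard @ w2_row_shard            # partial sum of [batch, out]
    dist.all_reduce(y_partial, op=dist.ReduceOp.SUM)
    return y_partial


def pipeline_stage(stage_fn, x_shape, prev_rank=None, next_rank=None):
    """Pipeline parallelism: stages only exchange activations via send/recv."""
    if prev_rank is not None:
        x = torch.empty(x_shape)
        dist.recv(x, src=prev_rank)               # activations from the previous stage
    else:
        x = torch.randn(x_shape)                  # first stage: read the input batch
    y = stage_fn(x)
    if next_rank is not None:
        dist.send(y, dst=next_rank)               # hand activations to the next stage
    return y
```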

zhhangBian commented 3 days ago

@AgrawalAmey Thank you for your response!

From my understanding, performing both a row-partition and a column-partition would only require an all-reduce, while performing only a row-partition or only a column-partition would require an all-gather operation.
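
For example, this is the kind of case I have in mind (plain torch.distributed pseudocode, not taken from Vidur): with only a column-partitioned linear layer, the output stays sharded across ranks, so reconstructing it seems to require an all_gather.

```python
# Rough illustration of my understanding (not Vidur code). A lone
# column-partitioned linear leaves the output feature dimension sharded,
# so the full output has to be reassembled with all_gather.
import torch
import torch.distributed as dist


def column_parallel_linear_alone(x, w_col_shard):
    y_shard = x @ w_col_shard                          # [batch, out / tp_size]
    shards = [torch.empty_like(y_shard) for _ in range(dist.get_world_size())]
    dist.all_gather(shards, y_shard)                   # collect every rank's slice
    return torch.cat(shards, dim=-1)                   # full [batch, out]
```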

However, I’m still curious. Could you kindly elaborate further on what you mean by “only use all-reduce and send/recv operations -- which are sufficient to represent tensor and pipeline parallelism”? Could you also explain the underlying principles and how this is implemented?

Thank you so much for your help!