parthmannan opened this issue 6 months ago
Thank you, Parth, for this excellent analysis and accompanying screenshots!
> AllGather operations during the forward pass are launched before the computation begins.
At some point our sorting broke, and we need to restore the intended functionality. Here's the issue tracking this: https://github.com/Lightning-AI/lightning-thunder/issues/277
> Is there a better way of allocating these buffers only in the first iteration and using a portion of these buffers for the computation instead of concat+copy every iteration?
Yes, there's a better way: with a special interleaving copy it should be possible to do fewer copies and more views. We don't have an issue tracking this yet, but creating microbenchmarks for bucketing is in our plans.
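Roughly, the idea as a minimal PyTorch sketch (not Thunder's actual implementation; `build_flat_bucket` and `all_gather_bucket` are made-up names, and the layout bookkeeping for sharded buckets is omitted): allocate the flat buffer once, re-point the parameters at views of it, and then each iteration is a single `all_gather_into_tensor` straight out of that persistent buffer with no concat/copy on the way in.

```python
import torch
import torch.distributed as dist

def build_flat_bucket(params):
    """One-time setup (first iteration only): copy the parameter shards into a
    single flat buffer and re-bind each parameter to a view of that buffer, so
    later iterations need no concat/copy kernels before the AllGather."""
    flat = torch.empty(sum(p.numel() for p in params),
                       dtype=params[0].dtype, device=params[0].device)
    offset = 0
    for p in params:
        n = p.numel()
        flat[offset:offset + n].copy_(p.detach().reshape(-1))
        p.data = flat[offset:offset + n].view_as(p)  # parameter is now a view
        offset += n
    return flat

def all_gather_bucket(flat_shard, world_size, out=None):
    """Per-iteration step: one AllGather out of the persistent buffer. The
    output buffer could likewise be allocated once and the gathered parameters
    exposed as views of it (layout handling not shown here)."""
    if out is None:
        out = torch.empty(flat_shard.numel() * world_size,
                          dtype=flat_shard.dtype, device=flat_shard.device)
    dist.all_gather_into_tensor(out, flat_shard)
    return out
```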
Update: ZeRO2 AllGather overlap issues were fixed in #383 and the performance is looking much better now.
🐛 Bug
This is a lengthy issue/post detailing my observations on our distributed and bucketing performance. Some of these are actionable items and some are just observations to be aware of.
FSDP ZeRO2 (Bucketing=None)
AllGather operations during the forward pass are launched before the computation begins. This is because the Thunder trace schedules all of the AllGathers before the computation and also calls the `wait` operators before any compute begins. The long line of operations in stream22 are all AG kernels. This is bad for performance because the forward compute cannot start until every AllGather (and its `wait`) has completed, so there is no communication/compute overlap.
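For illustration, this is the kind of interleaved launch/wait schedule that would allow overlap, as a minimal sketch with plain `torch.distributed` (not the actual Thunder trace; `layers`, `shards`, and the layer call signature are placeholders): each layer's AllGather is launched asynchronously one step ahead, and its `wait` happens only right before that layer's compute, so the current layer's compute runs while the next AllGather is in flight.

```python
import torch
import torch.distributed as dist

def forward_with_overlap(layers, shards, world_size, x):
    # Assumes one flat parameter shard per layer (placeholder structure).
    gathered = [torch.empty(s.numel() * world_size, dtype=s.dtype, device=s.device)
                for s in shards]
    handles = [None] * len(layers)

    # Prefetch the first layer's parameters.
    handles[0] = dist.all_gather_into_tensor(gathered[0], shards[0], async_op=True)

    for i, layer in enumerate(layers):
        # Launch the next layer's AllGather before waiting, so it runs on the
        # communication stream while the current layer computes.
        if i + 1 < len(layers):
            handles[i + 1] = dist.all_gather_into_tensor(
                gathered[i + 1], shards[i + 1], async_op=True)
        handles[i].wait()            # wait only for the params this layer needs
        x = layer(x, gathered[i])    # compute overlaps the in-flight AllGather
    return x
```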
FSDP ZeRO2 (Bucketing=Block)
There are `CatArrayBatchedCopy_contig` kernels for each AllGather operation, which would prevent overlap with compute even if the AG launch schedule were interleaved. There are `direct_copy_kernel` kernels as well when using larger bucketing. This would also prevent any overlap with compute and degrade performance. My understanding is that this might be because we concat the parameters and copy them into a buffer before communication. Is there a better way of allocating these buffers only in the first iteration and using a portion of these buffers for the computation instead of concat+copy every iteration?
For example, below is the execution timeline for `TorchInductor`.
FSDP ZeRO3 (Bucketing=None)
When using ZeRO3, the schedule is as expected, with the AG kernels and compute kernels being interleaved. However, due to launch overheads and small message sizes without bucketing, there are many gaps where the compute is not being overlapped with communication. There is probably room for improvement in the launch overheads (maybe even the schedule?) to improve performance, but there is no fundamental bug here. This is just an observation.
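To make the trade-off concrete, a generic sketch (not Thunder's bucketing implementation; unpacking of the gathered bucket layout is omitted): without bucketing every small shard is its own collective, so launch overhead dominates and leaves gaps, while bucketing replaces many small AllGathers with one larger one at the cost of the concat/copy kernels discussed above.

```python
import torch
import torch.distributed as dist

def all_gather_unbucketed(shards, world_size):
    # Bucketing=None: one small AllGather (and one kernel launch) per shard.
    outs = []
    for s in shards:
        out = torch.empty(s.numel() * world_size, dtype=s.dtype, device=s.device)
        dist.all_gather_into_tensor(out, s)
        outs.append(out)
    return outs

def all_gather_bucketed(shards, world_size):
    # Bucketing: flatten the group of shards into one buffer and issue a single,
    # larger AllGather, amortizing launch overhead. The gathered layout is
    # rank-major, so per-parameter views need extra bookkeeping (not shown).
    flat = torch.cat([s.reshape(-1) for s in shards])
    out = torch.empty(flat.numel() * world_size, dtype=flat.dtype, device=flat.device)
    dist.all_gather_into_tensor(out, flat)
    return out
```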
FSDP ZeRO3 (Bucketing=Block)
There are `direct_copy_kernel` kernels being launched with the AG that add overhead to the compute stream.
I am writing all of this here to have an easy comparison of all the options tried and to facilitate discussion. Please let me know if some of these require individual issues to track and I can create those.
cc @carmocca @awaelchli @crcrpar @IvanYashchuk @mruberry @t-vi @tfogal