Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

[Tensor Parallelism] Improve comm optimization logic for a pair of column-wise parallel linear and row-wise parallel linear #594

Open crcrpar opened 2 weeks ago

crcrpar commented 2 weeks ago

As the TODO comment at https://github.com/Lightning-AI/lightning-thunder/blob/24e81cf85f41ccd9dee2d1641a15b91aec65d34d/thunder/tests/distributed/test_tensor_parallel.py#L297-L300 says, https://github.com/Lightning-AI/lightning-thunder/blob/24e81cf85f41ccd9dee2d1641a15b91aec65d34d/thunder/distributed/tensor_parallel/optimize_comm.py#L28 currently fails, in some cases, to identify and remove a pair made of the post-processing of a column-wise parallel linear and the pre-processing of the row-wise parallel linear that follows it.
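To make the idea concrete, here is a minimal single-process sketch. It assumes the column-wise post-processing is an all-gather along the last dimension and the row-wise pre-processing is a split along that same dimension, so on each rank `split(all_gather(x))` just returns `x` and the pair can be dropped. The toy IR (`Op`), the helpers `all_gather`/`split`, and `remove_redundant_comm_pairs` are all hypothetical illustrations, not Thunder's API or the actual logic in `optimize_comm.py`.

```python
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass


# Toy stand-ins for the collectives, simulated in one process (hypothetical).
def all_gather(shards: list[list[int]]) -> list[int]:
    # Concatenate every rank's shard along the (only) dimension.
    return [x for shard in shards for x in shard]


def split(full: list[int], world_size: int, rank: int) -> list[int]:
    # Return this rank's contiguous chunk of the full tensor.
    n = len(full) // world_size
    return full[rank * n:(rank + 1) * n]


@dataclass
class Op:
    # Hypothetical trace node: op kind, input names, output name, and
    # the dimension a collective operates on (None for other ops).
    kind: str
    inputs: tuple[str, ...]
    output: str
    dim: int | None = None


def remove_redundant_comm_pairs(ops: list[Op]) -> list[Op]:
    """Drop all_gather -> split pairs on the same dim when the gathered
    tensor feeds only the split, and rewire consumers to the shard."""
    consumers = Counter(name for op in ops for name in op.inputs)
    producer = {op.output: i for i, op in enumerate(ops)}
    rename: dict[str, str] = {}
    removed: set[int] = set()
    for i, op in enumerate(ops):
        if op.kind != "split":
            continue
        j = producer.get(op.inputs[0])
        if j is None:
            continue
        src = ops[j]
        # Only safe if nothing else needs the full (gathered) tensor.
        if src.kind == "all_gather" and src.dim == op.dim and consumers[src.output] == 1:
            rename[op.output] = src.inputs[0]
            removed.update((i, j))

    def resolve(name: str) -> str:
        # Follow chains of removed pairs back to the original shard.
        while name in rename:
            name = rename[name]
        return name

    return [
        Op(op.kind, tuple(resolve(n) for n in op.inputs), op.output, op.dim)
        for i, op in enumerate(ops)
        if i not in removed
    ]


# Column-wise linear -> all_gather -> split -> row-wise linear -> all_reduce.
trace = [
    Op("linear_colwise", ("x", "w1"), "h_local"),
    Op("all_gather", ("h_local",), "h_full", dim=-1),
    Op("split", ("h_full",), "h_shard", dim=-1),
    Op("linear_rowwise", ("h_shard", "w2"), "y_partial"),
    Op("all_reduce", ("y_partial",), "y"),
    Op("return", ("y",), ""),
]
optimized = remove_redundant_comm_pairs(trace)
print([op.kind for op in optimized])
# ['linear_colwise', 'linear_rowwise', 'all_reduce', 'return']
print(optimized[1].inputs)  # ('h_local', 'w2')
```

The consumer-count check is a deliberately conservative choice in this sketch: if the gathered tensor is also used elsewhere (for example, returned from the trace), the all-gather must stay, so the pair is left untouched.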