PR #17173: [NVIDIA GPU] Optimize collective matmul loops when contracting dim is sharded

Imported from GitHub PR https://github.com/openxla/xla/pull/17173

When the contracting dim is sharded, both LHS and RHS are sharded and the results of the partial dots must be accumulated. However, the accumulation is done one step at a time, which emits a chain of individual add kernels. These add kernels hurt overlap between gemms because:
- they each occupy SMs, which prevents gemms from overlapping
- they might be fused by gemm_rewriter, which introduces data dependencies between gemms

This PR gets rid of the partial accumulations: instead, all partials are concatenated into a contiguous buffer and a single global reduction is applied to the buffer.
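To make the transformation concrete, here is a minimal numeric sketch of the two accumulation strategies, written with jax.numpy rather than the HLO the pass actually emits; the function names and shard lists are illustrative, not the rewrite's real code:

```python
# Sketch only: jax.numpy stands in for HLO, and all names here are
# illustrative assumptions, not identifiers from the XLA pass.
import jax.numpy as jnp

def stepwise(lhs_shards, rhs_shards):
    # Before: one add per loop step, each a separate kernel that
    # occupies SMs and serializes against the gemms.
    acc = lhs_shards[0] @ rhs_shards[0]
    for lhs, rhs in zip(lhs_shards[1:], rhs_shards[1:]):
        acc = acc + (lhs @ rhs)  # individual add kernel per step
    return acc

def concat_then_reduce(lhs_shards, rhs_shards):
    # After: every partial dot is written into one contiguous buffer,
    # and a single global reduction replaces the chain of adds.
    partials = jnp.stack(
        [lhs @ rhs for lhs, rhs in zip(lhs_shards, rhs_shards)])
    return jnp.sum(partials, axis=0)  # one reduction over the buffer
```

Both variants compute the same result; the second replaces N-1 data-dependent adds with one reduction over the contiguous buffer, which is what removes the obstacles to gemm overlap described above.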
Copybara import of the project:
--
b90997cc33ca3b17ea8e7bbceefd6951a5f442be by TJ Xu <tjx@nvidia.com>:

Optimize allgather collective matmul loops when contracting dim is sharded
--
a55b2f468d70623eaca90abb175c03d5c475dfd2 by TJ Xu <tjx@nvidia.com>:
Address pr comments
Merging this change closes #17173
FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17173 from Tixxx:tixxx/cm_contracting a55b2f468d70623eaca90abb175c03d5c475dfd2