openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

PR #17173: [NVIDIA GPU] Optimize collective matmul loops when contracting dim is sharded #17921

Closed: copybara-service[bot] closed this 1 week ago

copybara-service[bot] commented 1 week ago


Imported from GitHub PR https://github.com/openxla/xla/pull/17173

When the contracting dimension is sharded, both LHS and RHS are sharded, and the results of the partial dots must be accumulated. However, the accumulation is done one step at a time, which launches a number of individual add kernels. These add kernels hurt the overlap between GEMMs because:

  1. each add kernel occupies SMs, which prevents the GEMMs from overlapping;
  2. the adds might be fused into the GEMMs by gemm_rewriter, which introduces a data dependency between GEMMs.
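As a minimal sketch of the baseline pattern described above, assuming plain Python lists stand in for device tensors (`matmul`, `shard_contracting_dim`, and `looped_accumulate` are hypothetical names, not XLA APIs), the loop below splits the contracting dimension K into shards, computes one partial dot per shard, and adds it into a running sum at every step. It shows only the arithmetic; the collectives that deliver each shard on a real device are not modeled.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dot: a is [M, K], b is [K, N], result is [M, N]."""
    k, n = len(b), len(b[0])
    return [[sum(row[i] * b[i][j] for i in range(k)) for j in range(n)] for row in a]


def shard_contracting_dim(a: Matrix, b: Matrix, num_shards: int) -> Tuple[List[Matrix], List[Matrix]]:
    """Hypothetical helper: split K evenly, LHS along columns and RHS along rows."""
    k = len(b)
    assert k % num_shards == 0, "K must divide evenly across shards"
    step = k // num_shards
    lhs = [[row[s * step:(s + 1) * step] for row in a] for s in range(num_shards)]
    rhs = [b[s * step:(s + 1) * step] for s in range(num_shards)]
    return lhs, rhs


def looped_accumulate(a: Matrix, b: Matrix, num_shards: int) -> Matrix:
    """Baseline sketch: add each partial dot into a running sum on every step."""
    lhs, rhs = shard_contracting_dim(a, b, num_shards)
    m, n = len(a), len(b[0])
    acc = [[0.0] * n for _ in range(m)]
    for s in range(num_shards):
        partial = matmul(lhs[s], rhs[s])
        # One elementwise add per step; the running sum `acc` chains
        # every step to the result of the previous one.
        acc = [[acc[i][j] + partial[i][j] for j in range(n)] for i in range(m)]
    return acc
```

Because the sum over K splits into a sum over shards, the result equals the unsharded `matmul(a, b)`; the cost is one add per loop step, which is the pattern the list above identifies as harmful to GEMM overlap.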

This PR removes the per-step partial accumulations. Instead, all partial results are concatenated into a contiguous buffer, and a single global reduction is applied over that buffer.

Copybara import of the project:
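The optimized pattern can be sketched the same way, again assuming plain Python lists stand in for device tensors and a list of S matrices stands in for the contiguous [S, M, N] buffer (`concat_then_reduce` and the helpers are hypothetical names, not XLA APIs). Each loop step writes its partial into its own slot, and one reduction over the shard dimension runs after the loop.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dot: a is [M, K], b is [K, N], result is [M, N]."""
    k, n = len(b), len(b[0])
    return [[sum(row[i] * b[i][j] for i in range(k)) for j in range(n)] for row in a]


def shard_contracting_dim(a: Matrix, b: Matrix, num_shards: int) -> Tuple[List[Matrix], List[Matrix]]:
    """Hypothetical helper: split K evenly, LHS along columns and RHS along rows."""
    k = len(b)
    assert k % num_shards == 0, "K must divide evenly across shards"
    step = k // num_shards
    lhs = [[row[s * step:(s + 1) * step] for row in a] for s in range(num_shards)]
    rhs = [b[s * step:(s + 1) * step] for s in range(num_shards)]
    return lhs, rhs


def concat_then_reduce(a: Matrix, b: Matrix, num_shards: int) -> Matrix:
    """Sketch: store every partial in its own slot, then reduce once at the end."""
    lhs, rhs = shard_contracting_dim(a, b, num_shards)
    m, n = len(a), len(b[0])
    # Stand-in for one [S, M, N] buffer holding all partial results.
    buffer: List[Matrix] = [[] for _ in range(num_shards)]
    for s in range(num_shards):
        # Each step only writes slot s and never reads another step's output,
        # so there is no add inside the loop.
        buffer[s] = matmul(lhs[s], rhs[s])
    # Single global reduction over the leading (shard) dimension.
    return [[sum(buffer[s][i][j] for s in range(num_shards)) for j in range(n)]
            for i in range(m)]
```

In this sketch the partial dots no longer depend on each other, so nothing sits between consecutive GEMMs; the trade-off is a buffer S times the size of the output plus one reduction after the loop.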

-- b90997cc33ca3b17ea8e7bbceefd6951a5f442be by TJ Xu tjx@nvidia.com:

Optimize allgather collective matmul loops when contracting dim is sharded

-- a55b2f468d70623eaca90abb175c03d5c475dfd2 by TJ Xu tjx@nvidia.com:

Address pr comments

Merging this change closes #17173

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17173 from Tixxx:tixxx/cm_contracting a55b2f468d70623eaca90abb175c03d5c475dfd2