apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[TIR] Enhance `LowerThreadAllreduce` pass to automatically infer shared memory scope #17442

Closed: LeiWang1999 closed this 1 week ago

LeiWang1999 commented 1 month ago

The `LowerThreadAllreduce` pass enables efficient block-level reduction. However, block reduction often requires a large amount of shared memory. The current implementation of `LowerThreadAllreduce` only supports allocating the reduction buffer in static shared memory (scope `shared`). This prevents the buffer from being merged with the kernel's other shared memory buffers when those are allocated in the dynamic `shared.dyn` scope:

A_shared: "shared.dyn"
B_shared: "shared.dyn"
C_shared: "shared.dyn"
red: "shared" (cannot be merged into the combined shared memory pool)
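The effect of the mismatch can be seen in a toy Python model (not TVM code). It assumes, following the description above, that the merge step only pools buffers whose scope is `shared.dyn`, and it packs them at consecutive offsets without the liveness-based reuse a real pass would perform. The `Buffer` class, the `merge_dynamic_shared` function, and the byte sizes are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Buffer:
    name: str
    scope: str   # e.g. "shared" or "shared.dyn"
    nbytes: int  # hypothetical size in bytes


def merge_dynamic_shared(buffers):
    """Toy merge: place every 'shared.dyn' buffer in one pool at consecutive
    offsets; buffers in any other scope remain separate allocations."""
    offsets = {}
    pool_size = 0
    separate = []
    for buf in buffers:
        if buf.scope == "shared.dyn":
            offsets[buf.name] = pool_size
            pool_size += buf.nbytes
        else:
            separate.append(buf.name)
    return pool_size, offsets, separate


buffers = [
    Buffer("A_shared", "shared.dyn", 4096),
    Buffer("B_shared", "shared.dyn", 4096),
    Buffer("C_shared", "shared.dyn", 4096),
    Buffer("red", "shared", 512),
]
pool_size, offsets, separate = merge_dynamic_shared(buffers)
print(pool_size, separate)  # 12288 ['red']
```

In this model `red` stays outside the pool simply because its scope differs, even though nothing else prevents it from sharing space with the other buffers.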

This pull request addresses the issue by first collecting the existing buffer allocations and then determining the memory scope of the reduction buffer from them, so that the subsequent `MergeSharedMemoryAllocations` pass can merge the reduction buffer into the same shared memory pool.
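The two-step idea (collect allocations, then pick the reduction buffer's scope) can be sketched as a small Python model. This is not the actual TVM C++ implementation. The specific rule used here, "use `shared.dyn` if any collected buffer already uses it, otherwise keep `shared`", is an assumption inferred from the problem description, and `Buffer`, `infer_reduction_scope`, and `lower_allreduce` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class Buffer:
    name: str
    scope: str   # e.g. "shared" or "shared.dyn"
    nbytes: int  # hypothetical size in bytes


def infer_reduction_scope(allocations):
    """Assumed rule: follow the kernel's existing dynamic shared memory so the
    reduction buffer can later be merged; otherwise stay static."""
    if any(buf.scope == "shared.dyn" for buf in allocations):
        return "shared.dyn"
    return "shared"


def lower_allreduce(allocations, red_nbytes):
    # Step 1: the kernel's allocations have already been collected (passed in).
    # Step 2: choose the reduction buffer's scope based on those allocations.
    scope = infer_reduction_scope(allocations)
    return allocations + [Buffer("red", scope, red_nbytes)]


dyn_kernel = [
    Buffer("A_shared", "shared.dyn", 4096),
    Buffer("B_shared", "shared.dyn", 4096),
    Buffer("C_shared", "shared.dyn", 4096),
]
static_kernel = [Buffer("A_shared", "shared", 4096)]

print(lower_allreduce(dyn_kernel, 512)[-1].scope)     # shared.dyn
print(lower_allreduce(static_kernel, 512)[-1].scope)  # shared
```

Keeping `shared` as the fallback means kernels that use only static shared memory would be unaffected under this assumed rule.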

Hzfengsy commented 1 month ago

Please add a test case for this enhancement.