apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bugfix] Fix improper touched buffer assignment of Pass MergeSharedMemoryAllocations #17438

Open LeiWang1999 opened 1 month ago

LeiWang1999 commented 1 month ago

As discussed in issue #17375, the current rule for assigning touched buffers is not appropriate. Consider the following example:

code_block_0
for k in range(0, 10):  # (currently, the gen points of A_shared and B_shared are attached to this for loop)
    for i in range(0, 10):
        A_shared <- A
    for i in range(0, 10):
        B_shared <- B
    code_block_1 (consume A_shared and B_shared)
code_block_2 (produce and consume C_shared)

This works only by chance in simple GEMM scenarios. The correct approach is:

code_block_0
for k in range(0, 10):
    for i in range(0, 10):
        A_shared <- A  # (the gen point of A_shared should be bound to this BufferStore node)
    for i in range(0, 10):
        B_shared <- B  # (the gen point of B_shared should be bound to this BufferStore node)
    code_block_1 (consume A_shared and B_shared)
code_block_2 (produce and consume C_shared)

This approach works correctly even in more complex scenarios, such as batched GEMM, where the naive template would fail.
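
To make the difference between the two rules concrete, here is a minimal sketch on a toy IR. It is not the TVM implementation: `Node`, `gen_points`, and the "attach to the outermost enclosing loop" rule are hypothetical simplifications used only to mirror the two pseudocode listings above.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class Node:
    """A statement in a toy IR (hypothetical, not a TVM class)."""
    name: str
    is_loop: bool = False
    touches: Tuple[str, ...] = ()  # shared buffers read or written by this node
    body: List["Node"] = field(default_factory=list)


def gen_points(root: Node, attach_to_outer_loop: bool) -> Dict[str, str]:
    """Map each buffer to the name of the node that owns its gen point.

    attach_to_outer_loop=True models the old behaviour (assumed here to be
    "outermost enclosing loop"); False binds the gen point to the node that
    first touches the buffer.
    """
    gen: Dict[str, str] = {}

    def visit(node: Node, outer_loop: Optional[Node]) -> None:
        # Remember the outermost loop we are nested in, if any.
        if node.is_loop and outer_loop is None:
            outer_loop = node
        for buf in node.touches:
            if buf not in gen:  # only the first touch defines the gen point
                use_loop = attach_to_outer_loop and outer_loop is not None
                gen[buf] = outer_loop.name if use_loop else node.name
        for child in node.body:
            visit(child, outer_loop)

    visit(root, None)
    return gen


# The program from the description, written as a tree.
program = Node("root", body=[
    Node("code_block_0"),
    Node("for_k", is_loop=True, body=[
        Node("for_i_A", is_loop=True, body=[
            Node("store_A_shared", touches=("A_shared",)),
        ]),
        Node("for_i_B", is_loop=True, body=[
            Node("store_B_shared", touches=("B_shared",)),
        ]),
        Node("code_block_1", touches=("A_shared", "B_shared")),
    ]),
    Node("code_block_2", touches=("C_shared",)),
])

naive = gen_points(program, attach_to_outer_loop=True)
precise = gen_points(program, attach_to_outer_loop=False)

print(naive)    # A_shared and B_shared both map to 'for_k'
print(precise)  # each maps to its own store node
```

Under the first rule both buffers appear to come alive at the same coarse point (`for_k`), so a liveness-based planner cannot tell their first uses apart. Under the second rule each buffer's lifetime starts at the store that actually writes it.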

This pull request makes a small modification to the MergeSharedMemoryAllocations pass to enable the correct analysis. It also always disables the naive shared memory buffer fusion in the StorageRewrite pass when the kernel uses dynamic shared memory.