apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[TIR] Enhance Lower cross thread Pass #17133

Open LeiWang1999 opened 3 days ago


We currently support lowering cross-thread reductions only under several constraints. For example, lower_cross_thread applies only when the thread-bound reduction axis is the innermost loop, and the block must have an init block. This can be limiting in some cases.

For example, when tensorizing a reduction block (e.g., with dp4a or mma), it is difficult to tensorize its init statement as well:

with T.block("block"):
    vi = T.axis.spatial(2, i_0 * 16 + i_1)
    vk = T.axis.reduce(32, k_0 * 64 + k_1)
    T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
    T.reads(A[vi, vk])
    T.writes(B[vi])
    with T.init():
        B[vi] = T.float32(0)
    B[vi] = B[vi] + A[vi, vk]
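As a reference for what the block above computes, here is a plain-Python model of its loop nest. It assumes the enclosing loops have extents i_0 in [0, 1), i_1 in [0, 16), k_0 in [0, 1) and k_1 in [0, 64). These extents are a guess consistent with the T.where predicate and are not taken from the PR. The function name reference_block is hypothetical.

```python
def reference_block(A):
    """Plain-Python model of the block: B[vi] = sum over vk of A[vi][vk].

    A is a 2 x 32 nested list. The loop extents below are assumptions
    chosen to match the T.where predicate in the TVMScript snippet.
    """
    B = [None] * 2
    for i_0 in range(1):
        for i_1 in range(16):
            for k_0 in range(1):
                for k_1 in range(64):
                    vi = i_0 * 16 + i_1
                    vk = k_0 * 64 + k_1
                    # T.where: skip iterations outside the 2 x 32 domain.
                    if not (vi < 2 and vk < 32):
                        continue
                    # T.init(): runs when the reduction axis is at its first value.
                    if vk == 0:
                        B[vi] = 0.0
                    B[vi] = B[vi] + A[vi][vk]
    return B
```

The init statement is what ties the zeroing of B[vi] to the reduction loop. A tensorized reduction body would have to reproduce that coupling, which is the difficulty described above.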

Moreover, some workloads, such as small GEMMs, benefit from block-level reduction in shared memory, which increases parallelism and makes better use of the hardware.
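The following plain-Python model shows a generic block-level reduction in shared memory. Each of num_threads threads first accumulates a strided partial sum of the reduction axis into its slot of a simulated shared-memory array. The slots are then combined by a tree reduction that halves the number of active threads at each step. This is a textbook sketch of the technique, not the code the pass emits. It assumes num_threads is a power of two, and block_reduce_sum is a hypothetical name.

```python
def block_reduce_sum(row, num_threads=8):
    """Sum `row` the way a thread block would, using a simulated shared-memory buffer."""
    assert num_threads > 0 and num_threads & (num_threads - 1) == 0, "power of two"
    # Phase 1: thread t serially accumulates row[t], row[t + num_threads], ...
    shared = [0.0] * num_threads
    for t in range(num_threads):
        acc = 0.0
        for k in range(t, len(row), num_threads):
            acc += row[k]
        shared[t] = acc
    # Phase 2: tree reduction in shared memory. On a GPU each step ends with a
    # barrier (__syncthreads()), and half as many threads stay active.
    stride = num_threads // 2
    while stride > 0:
        for t in range(stride):
            shared[t] += shared[t + stride]
        stride //= 2
    return shared[0]
```

For example, `block_reduce_sum([float(x) for x in range(32)])` splits the 32 elements across 8 threads and combines the partial sums in three tree steps.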

This pull request improves the lower_cross_thread pass. It can now lower thread-block reductions whose init and reduction are written as separate blocks. It also removes the constraint that the reduced axis must be the innermost loop, which allows TensorCore kernels to use block reduction.
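The plain-Python model below illustrates the semantics such a lowering has to preserve. It is not the pass's implementation. It assumes the reduction axis is split as vk = k_outer * inner + k_inner. Here k_outer is bound to threads, and k_inner is an inner serial loop, the position a tensorized dp4a/mma chunk would occupy, so the thread-bound axis is not innermost. Because B is written by a separate init block, the combined cross-thread result is accumulated into B rather than assigned. The names init_block and reduce_block_lowered are hypothetical.

```python
def init_block(B):
    # Separate init block: zeroes the output once, before the reduction runs.
    for vi in range(len(B)):
        B[vi] = 0.0


def reduce_block_lowered(A, B, num_threads=4, inner=8):
    """Model of lowering B[vi] = B[vi] + A[vi][vk] with a non-innermost thread axis."""
    for vi in range(len(B)):
        partials = [0.0] * num_threads              # one private accumulator per thread
        for k_outer in range(num_threads):          # modeled as bound to threadIdx.x
            for k_inner in range(inner):            # serial inner loop (e.g. a tensorized chunk)
                vk = k_outer * inner + k_inner
                if vk < len(A[vi]):                 # bounds predicate, like T.where
                    partials[k_outer] += A[vi][vk]
        # Cross-thread combine (a warp shuffle or shared-memory tree on a GPU).
        total = sum(partials)
        # The init lives in another block, so accumulate instead of overwriting B[vi].
        B[vi] = B[vi] + total
```

Running `init_block` followed by `reduce_block_lowered` gives the same result as a single block with an init statement. Calling `reduce_block_lowered` on a non-zero B adds to the existing values, which is the behavior a separate init block requires.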

Relevant test cases can be found in tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py.

CC @MasterJH5574