tlc-pack / tvm-tensorir

Apache License 2.0
8 stars 0 forks source link

[BUG] Tensorize when loop extent = 1 #381

Open vinx13 opened 3 years ago

vinx13 commented 3 years ago

Tensorize currently doesn't work when axis of a buffer has extent = 1. See the example.

import tvm
from tvm import te, tir
from tvm.script import ty

@tvm.script.tir
def intrin_mma_desc(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Tensor intrinsic *description*: a 32x32 output tile updated by an
    # outer product whose reduction axis deliberately has extent 1 -- the
    # extent-1 axis is what triggers the reported tensorize failure.
    A = tir.match_buffer(a, (32, 1), "float32",  scope="global", offset_factor=1)
    B = tir.match_buffer(b, (32, 1), "float32",  scope="global", offset_factor=1)
    C = tir.match_buffer(c, (32, 32), "float32", scope="global", offset_factor=1)
    # Root block: all three iter vars are pinned to 0; the reduce axis vk has
    # range [0, 1).
    with tir.block([32, 32, tir.reduce_axis(0, 1)], "root") as [vi, vj, vk]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        tir.reads([C[vi:vi+32, vj:vj+32], A[vi:vi+32,vk:vk+1], B[vj:vj+32,vk:vk+1]])
        tir.writes(C[vi:vi+32, vj:vj+32])
        # Scalar reference computation the pattern matcher compares against.
        # Note the k loop has extent 1; per the report its iter var gets
        # eliminated from the block bindings, breaking the match.
        for i, j, k in tir.grid(32, 32, 1):
            with tir.block([32, 32, tir.reduce_axis(0, 1)], "B") as [vii, vjj, vkk]:
                tir.bind(vii, vi + i)
                tir.bind(vjj, vj + j)
                tir.bind(vkk, vk)
                C[vii, vjj] = C[vii, vjj] + A[vii,vkk] * B[vjj,vkk]

@tvm.script.tir
def intrin_mma_impl(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Tensor intrinsic *implementation*: same buffer signature and read/write
    # regions as the description, but the scalar loop nest is replaced by a
    # single tvm_mma_sync hardware intrinsic call.
    A = tir.match_buffer(a, (32, 1), "float32",  scope="global", offset_factor=1)
    B = tir.match_buffer(b, (32, 1), "float32",  scope="global", offset_factor=1)
    C = tir.match_buffer(c, (32, 32), "float32", scope="global", offset_factor=1)
    with tir.block([32, 32, tir.reduce_axis(0, 1)], "root") as [vi, vj, vk]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        tir.reads([C[vi:vi+32, vj:vj+32], A[vi:vi+32, vk:vk+1], B[vj:vj+32,vk:vk+1]])
        tir.writes(C[vi:vi+32, vj:vj+32])
        # elem_offset divisors convert byte-level element offsets into tile
        # indices: 1024 = 32*32 elements per C tile, 32 elements per A/B tile.
        # Per the report, B.elem_offset lowers incorrectly because the vk
        # binding is mis-detected when the extent-1 loop is eliminated.
        tir.evaluate(tir.tvm_mma_sync(C.data, C.elem_offset // 1024, A.data, A.elem_offset // 32, B.data, B.elem_offset // 32, dtype='handle'))

@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Workload: 128x128 matmul. Note B is indexed as B[vj, vk], i.e. the
    # second operand is accessed in transposed layout, matching the (32, 1)
    # B buffer layout of the intrinsic above.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

def main():
    # Reproduction driver: tile the matmul into 32x32x1 blocks and try to
    # tensorize the inner tile with the intrinsic defined above.
    mod = tvm.script.create_module({'main': matmul})
    s = tir.Schedule(mod)
    C = s.get_block('C')
    i, j, k = s.get_axes(C)
    i0, i1 = s.split(i, factor=32)
    j0, j1 = s.split(j, factor=32)
    # factor=1 produces an inner reduction loop of extent 1 -- the case that
    # fails, because the extent-1 loop var is eliminated from the block
    # bindings before the pattern matcher compares against the description.
    k0, k1 = s.split(k, factor=1)
    s.reorder(i0, j0, k0, i1, j1, k1)
    # Expected to match (i1, j1, k1) against intrin_mma_desc; currently raises
    # due to the loop/description mismatch described in this issue.
    s.tensorize(i1, tir.TensorIntrin(intrin_mma_desc, intrin_mma_impl))

    print(tvm.script.asscript(s.mod['main']))

main()

The above code doesn't work because of a mismatch between the loop nest and the tensor intrinsic description. The loop k1 is eliminated from the block iter vars (this is because of this).

If I remove this part of the code, we still need to fix the pattern matcher here https://github.com/Hzfengsy/tvm-tensorir/blob/main/src/tir/schedule/primitives/blockize_tensorize.cc#L73 because the original loop after blockize will be

block C(iter_var(vi, range(min=0, ext=128)), iter_var(vj, range(min=0, ext=128)), iter_var(vk, range(min=0, ext=128)){
  bind(vi, ((vio*32) + i0_inner))
  bind(vj, ((vjo*32) + i1_inner))
  bind(vk, vko)  # the inner loop var of extent 1 is still eliminated.
  reads([C[vi, vj], A[vi, vk], B[vj, vk]])
  writes([C[vi, vj]])
  C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))
}

and as a result B.elem_offset is lowered to get_elem_offset(B[vjo * 32, 0]) instead of get_elem_offset(B[vjo * 32, vko]) because the detected binding of vk is incorrect.

The design question here is whether we should eliminate loops of extent 1 during blockize and tensorize.

junrushao commented 2 years ago

Is this issue addressed yet? If so let's close this issue

vinx13 commented 2 years ago

I'll double check this when upstreaming

Hzfengsy commented 2 years ago

cc @spectrometerHBH

Hzfengsy commented 2 years ago

I guess it is because of affine map