tlc-pack / tvm-tensorir

Apache License 2.0
8 stars 0 forks source link

[BUG][MetaSchedule] Cuda Tensorcore Integration Test Error #439

Closed zxybazh closed 2 years ago

zxybazh commented 3 years ago

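While running `test_integration_cuda_tensorcore.py`, I got the following error. The first module below is the scheduled 512×512 matmul IR. The traceback after it comes from `test_integration_conv2d_nchwc`, where `sch.fuse` rejects the `W_shared` loop nest.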

# TVMScript dump of the scheduled 512x512 matmul: A and B are float16, and the
# product is accumulated in wmma fragments before being written back to C.
@tvm.script.tir
class Module:
    def main(var_A: ty.handle, var_B: ty.handle, var_C: ty.handle) -> None:
        A = tir.match_buffer(var_A, [512, 512], dtype="float16", elem_offset=0, align=128, offset_factor=1)
        B = tir.match_buffer(var_B, [512, 512], dtype="float16", elem_offset=0, align=128, offset_factor=1)
        C = tir.match_buffer(var_C, [512, 512], elem_offset=0, align=128, offset_factor=1)
        # body
        with tir.block([], "root"):
            tir.reads([])
            tir.writes([])
            # Staging: A/B -> shared -> wmma.matrix_a / wmma.matrix_b fragments.
            # The result lives in wmma.accumulator, is stored to C_local
            # (scope "local"), and is finally copied to C.
            C_local = tir.alloc_buffer([512, 512], elem_offset=0, scope="local", align=128, offset_factor=1)
            A_shared = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="shared", align=128, offset_factor=1)
            B_shared = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="shared", align=128, offset_factor=1)
            A_shared_wmma_matrix_a = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="wmma.matrix_a", align=128, offset_factor=1)
            B_shared_wmma_matrix_b = tir.alloc_buffer([512, 512], dtype="float16", elem_offset=0, scope="wmma.matrix_b", align=128, offset_factor=1)
            C_local_wmma_accumulator = tir.alloc_buffer([512, 512], elem_offset=0, scope="wmma.accumulator", align=128, offset_factor=1)
            # Launch layout: 2 blocks (blockIdx.x) x 4 virtual threads x 16 threads (threadIdx.x).
            for i0_0_0_i1_0_0_fused in tir.thread_binding(0, 2, thread = "blockIdx.x"):
                for i0_0_1_i1_0_1_fused in tir.thread_binding(0, 4, thread = "vthread"):
                    for i0_0_2_i1_0_2_fused in tir.thread_binding(0, 16, thread = "threadIdx.x"):
                        # Zero-fill the 2 x 4 grid of 16x16 accumulator tiles handled per thread.
                        for i0_0_3_init, i1_0_3_init in tir.grid(2, 4):
                            with tir.block([32, 32], "blockized_C_init") as [io_init, jo_init]:
                                tir.bind(io_init, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3_init))
                                tir.bind(jo_init, ((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3_init))
                                tir.reads([])
                                tir.writes([C_local_wmma_accumulator[(io_init*16):((io_init*16) + 16), (jo_init*16):((jo_init*16) + 16)]])
                                with tir.block([1, 1], "blockized_C_init") as [i_inito, j_inito]:
                                    tir.bind(i_inito, 0)
                                    tir.bind(j_inito, 0)
                                    tir.reads([])
                                    tir.writes([C_local_wmma_accumulator[(io_init*16):((io_init*16) + 16), (jo_init*16):((jo_init*16) + 16)]])
                                    # NOTE(review): the fragment index is derived from element [0, 0], not from the
                                    # tile corner (io_init*16, jo_init*16), so every tile gets the same index in this
                                    # dump. The same pattern appears in the loads, mma and store below. Confirm the
                                    # tensorized sub-region offset is not being dropped.
                                    tir.evaluate(tir.tvm_fill_fragment(C_local_wmma_accumulator.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), tir.float32(0), dtype="handle"))
                        # Reduction over K: 2 outer steps, each covering 256 elements of K.
                        for i2_0_0 in tir.serial(0, 2):
                            # Copy a 256 (K) x 256 (N) slab of B to shared: 16384 x 4-wide vectors = 65536 elements.
                            # The loop is annotated lazy_cooperative_fetch, presumably so a later pass splits it
                            # across threadIdx.x. TODO confirm.
                            for ax0_ax1_fused_0 in tir.serial(0, 16384, annotation = {"loop_type":"lazy_cooperative_fetch"}):
                                for ax0_ax1_fused_1 in tir.vectorized(0, 4):
                                    with tir.block([512, 512], "B_shared") as [v0, v1]:
                                        tir.bind(v0, ((i2_0_0*256) + tir.floordiv(((ax0_ax1_fused_0*4) + ax0_ax1_fused_1), 256)))
                                        tir.bind(v1, ((i0_0_0_i1_0_0_fused*256) + tir.floormod(((ax0_ax1_fused_0*4) + ax0_ax1_fused_1), 256)))
                                        tir.reads([B[v0, v1]])
                                        tir.writes([B_shared[v0, v1]])
                                        B_shared[v0, v1] = B[v0, v1]
                            # Copy a 512 (M) x 256 (K) slab of A to shared: 32768 x 4-wide vectors = 131072 elements.
                            for ax0_ax1_fused_0_1 in tir.serial(0, 32768, annotation = {"loop_type":"lazy_cooperative_fetch"}):
                                for ax0_ax1_fused_1_1 in tir.vectorized(0, 4):
                                    with tir.block([512, 512], "A_shared") as [v0_1, v1_1]:
                                        tir.bind(v0_1, tir.floordiv(((ax0_ax1_fused_0_1*4) + ax0_ax1_fused_1_1), 256))
                                        tir.bind(v1_1, ((i2_0_0*256) + tir.floormod(((ax0_ax1_fused_0_1*4) + ax0_ax1_fused_1_1), 256)))
                                        tir.reads([A[v0_1, v1_1]])
                                        tir.writes([A_shared[v0_1, v1_1]])
                                        A_shared[v0_1, v1_1] = A[v0_1, v1_1]
                            # Inner K tiles (i2_0_1 x i2_0_2 = 8 x 2) over the thread's 2 x 4 grid of 16x16
                            # output tiles (i0_0_3, i1_0_3). Each iteration loads fragments, runs mma, and stores.
                            for i2_0_1, i0_0_3, i1_0_3, i2_0_2, i0_0_4, i1_0_4 in tir.grid(8, 2, 4, 2, 1, 1):
                                with tir.block([32, 32], "blockized_B_shared_wmma.matrix_b") as [v0o, v1o]:
                                    tir.bind(v0o, (((i2_0_0*16) + (i2_0_1*2)) + i2_0_2))
                                    tir.bind(v1o, ((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3))
                                    tir.reads([B_shared[(v0o*16):((v0o*16) + 16), (v1o*16):((v1o*16) + 16)]])
                                    tir.writes([B_shared_wmma_matrix_b[(v0o*16):((v0o*16) + 16), (v1o*16):((v1o*16) + 16)]])
                                    # NOTE(review): the leading-dimension argument is 16, but B_shared is allocated
                                    # 512 wide. Confirm the stride matches the buffer actually passed (same for A_shared
                                    # and C_local below).
                                    tir.evaluate(tir.tvm_load_matrix_sync(B_shared_wmma_matrix_b.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(B_shared_wmma_matrix_b[0, 0], dtype="int32"), 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), B_shared.data, tir.get_elem_offset(B_shared[0, 0], dtype="int32"), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
                                with tir.block([32, 32], "blockized_A_shared_wmma.matrix_a") as [v0o_1, v1o_1]:
                                    tir.bind(v0o_1, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3))
                                    tir.bind(v1o_1, (((i2_0_0*16) + (i2_0_1*2)) + i2_0_2))
                                    tir.reads([A_shared[(v0o_1*16):((v0o_1*16) + 16), (v1o_1*16):((v1o_1*16) + 16)]])
                                    tir.writes([A_shared_wmma_matrix_a[(v0o_1*16):((v0o_1*16) + 16), (v1o_1*16):((v1o_1*16) + 16)]])
                                    tir.evaluate(tir.tvm_load_matrix_sync(A_shared_wmma_matrix_a.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(A_shared_wmma_matrix_a[0, 0], dtype="int32"), 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), A_shared.data, tir.get_elem_offset(A_shared[0, 0], dtype="int32"), 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
                                # 16x16x16 tensor-core update: C tile (io, jo) += A tile (io, ko) * B tile (ko, jo).
                                with tir.block([32, 32, tir.reduce_axis(0, 32)], "blockized_C_update") as [io, jo, ko]:
                                    tir.bind(io, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3))
                                    tir.bind(jo, ((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3))
                                    tir.bind(ko, (((i2_0_0*16) + (i2_0_1*2)) + i2_0_2))
                                    tir.reads([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)], A_shared_wmma_matrix_a[(io*16):((io*16) + 16), (ko*16):((ko*16) + 16)], B_shared_wmma_matrix_b[(ko*16):((ko*16) + 16), (jo*16):((jo*16) + 16)]])
                                    tir.writes([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)]])
                                    with tir.block([1, 1, tir.reduce_axis(0, 1)], "blockized_C") as [io_1, jo_1, ko_1]:
                                        tir.bind(io_1, 0)
                                        tir.bind(jo_1, 0)
                                        tir.bind(ko_1, 0)
                                        tir.reads([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)], A_shared_wmma_matrix_a[(io*16):((io*16) + 16), (ko*16):((ko*16) + 16)], B_shared_wmma_matrix_b[(ko*16):((ko*16) + 16), (jo*16):((jo*16) + 16)]])
                                        tir.writes([C_local_wmma_accumulator[(io*16):((io*16) + 16), (jo*16):((jo*16) + 16)]])
                                        tir.evaluate(tir.tvm_mma_sync(C_local_wmma_accumulator.data, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), A_shared_wmma_matrix_a.data, tir.floordiv(tir.get_elem_offset(A_shared_wmma_matrix_a[0, 0], dtype="int32"), 256), B_shared_wmma_matrix_b.data, tir.floordiv(tir.get_elem_offset(B_shared_wmma_matrix_b[0, 0], dtype="int32"), 256), C_local_wmma_accumulator.data, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), dtype="handle"))
                                # Store the accumulator fragment to C_local (tvm_access_ptr mask 2 = write).
                                # NOTE(review): this store is nested inside the reduction loops (i2_0_0, i2_0_1,
                                # i2_0_2), so it re-runs on every K step. Only the last write holds the final sum.
                                # It likely belongs after the K loops.
                                with tir.block([32, 32], "blockized_C_local_wmma.accumulator") as [v0o_2, v1o_2]:
                                    tir.bind(v0o_2, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*16) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*2)) + i0_0_3))
                                    tir.bind(v1o_2, ((((i0_0_0_i1_0_0_fused*16) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*8)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*4)) + i1_0_3))
                                    tir.reads([C_local_wmma_accumulator[(v0o_2*16):((v0o_2*16) + 16), (v1o_2*16):((v1o_2*16) + 16)]])
                                    tir.writes([C_local[(v0o_2*16):((v0o_2*16) + 16), (v1o_2*16):((v1o_2*16) + 16)]])
                                    tir.evaluate(tir.tvm_store_matrix_sync(C_local_wmma_accumulator.data, 16, 16, 16, tir.floordiv(tir.get_elem_offset(C_local_wmma_accumulator[0, 0], dtype="int32"), 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), C_local.data, tir.get_elem_offset(C_local[0, 0], dtype="int32"), 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
                        # Copy this thread's 32 x 64 region (its 2 x 4 grid of 16x16 tiles) from C_local to C.
                        for ax0, ax1 in tir.grid(32, 64):
                            with tir.block([512, 512], "C_local") as [v0_2, v1_2]:
                                tir.bind(v0_2, (((tir.floordiv(i0_0_1_i1_0_1_fused, 2)*256) + (tir.floordiv(i0_0_2_i1_0_2_fused, 2)*32)) + ax0))
                                tir.bind(v1_2, ((((i0_0_0_i1_0_0_fused*256) + (tir.floormod(i0_0_1_i1_0_1_fused, 2)*128)) + (tir.floormod(i0_0_2_i1_0_2_fused, 2)*64)) + ax1))
                                tir.reads([C_local[v0_2, v1_2]])
                                tir.writes([C[v0_2, v1_2]])
                                C[v0_2, v1_2] = C_local[v0_2, v1_2]

Traceback (most recent call last):
  File "test_integration_cuda_tensorcore.py", line 229, in <module>
    test_integration_conv2d_nchwc()
  File "test_integration_cuda_tensorcore.py", line 224, in test_integration_conv2d_nchwc
    schedule(sch)
  File "test_integration_cuda_tensorcore.py", line 207, in schedule
    fused = sch.fuse(*sch.get_loops(w_read)[-6:])
  File "/home/zxybazh/tvm-tensorir/python/tvm/tir/schedule/schedule.py", line 412, in fuse
    return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # pylint: disable=no-member
  File "/home/zxybazh/tvm-tensorir/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm.tir.schedule.schedule.ScheduleError: ScheduleError: An error occurred in the schedule primitive 'fuse'.
The IR is:
# IR at the point where sch.fuse failed in test_integration_conv2d_nchwc: an
# NCHWc conv2d reading float16 X and W and accumulating in float32 into
# conv2d_nchwc_local.
@tvm.script.tir
class Module:
    def main(var_X: ty.handle, var_W: ty.handle, var_conv2d_nchwc: ty.handle) -> None:
        X = tir.match_buffer(var_X, [1, 6, 98, 98, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
        W = tir.match_buffer(var_W, [12, 6, 3, 3, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1)
        conv2d_nchwc = tir.match_buffer(var_conv2d_nchwc, [1, 12, 96, 96, 16], elem_offset=0, align=128, offset_factor=1)
        # body
        with tir.block([], "root"):
            tir.reads([])
            tir.writes([])
            conv2d_nchwc_local = tir.alloc_buffer([1, 12, 96, 96, 16], elem_offset=0, scope="local", align=128, offset_factor=1)
            W_shared = tir.alloc_buffer([12, 6, 3, 3, 16, 16], dtype="float16", elem_offset=0, scope="shared", align=128, offset_factor=1)
            for i0 in tir.serial(0, 1):
                for i1_0_i2_0_i3_0_0_i4_0_0_fused in tir.thread_binding(0, 4, thread = "blockIdx.x"):
                    for i1_1_i2_1_i3_0_1_i4_0_1_fused in tir.thread_binding(0, 2, thread = "vthread"):
                        for i1_2_i2_2_i3_0_2_i4_0_2_fused in tir.thread_binding(0, 12, thread = "threadIdx.x"):
                            # Outer reduction loops. i5_0_0 covers rc in 2 steps of 48. i6_0 has extent 1.
                            # i7_0 carries rw in full (extent 3); i7_1 and i7_2 below both have extent 1.
                            for i5_0_0, i6_0, i7_0 in tir.grid(2, 1, 3):
                                # W_shared cache read, placed under i7_0 (presumably via compute_at).
                                # Only FIVE copy loops exist (ax0, ax1, ax2, ax4, ax5): the rw dimension gets no
                                # loop because v3 is bound directly to i7_0. So get_loops(W_shared)[-6:] (test
                                # line 208) is [i7_0, ax0, ax1, ax2, ax4, ax5]. i7_0 also contains the compute
                                # nest below, which is why fuse reports "not the only child" (tir.For#0 = i7_0).
                                # NOTE(review): fusing only the loops that actually belong to the copy (here [-5:],
                                # or counting them instead of hard-coding 6) should avoid the error. Untested.
                                for ax0, ax1, ax2, ax4, ax5 in tir.grid(12, 3, 3, 16, 16):
                                    with tir.block([12, 6, 3, 3, 16, 16], "W_shared") as [v0, v1, v2, v3, v4, v5]:
                                        tir.bind(v0, ax0)
                                        tir.bind(v1, ((i5_0_0*3) + ax1))
                                        tir.bind(v2, ax2)
                                        tir.bind(v3, i7_0)
                                        tir.bind(v4, ax4)
                                        tir.bind(v5, ax5)
                                        tir.reads([W[v0, v1, v2, v3, v4, v5]])
                                        tir.writes([W_shared[v0, v1, v2, v3, v4, v5]])
                                        W_shared[v0, v1, v2, v3, v4, v5] = W[v0, v1, v2, v3, v4, v5]
                                # Compute nest, a sibling of the W_shared copy under i7_0. The copy therefore
                                # re-runs for each (i5_0_0, i7_0) pair.
                                for i5_0_1, i6_1, i7_1, i1_3, i2_3, i3_0_3, i4_0_3, i5_0_2, i6_2, i7_2, i1_4, i2_4, i3_0_4, i4_0_4, i3_1, i4_1, i5_1 in tir.grid(1, 3, 1, 1, 2, 1, 1, 3, 1, 1, 6, 1, 6, 1, 16, 16, 16):
                                    with tir.block([1, 12, 96, 96, 16, tir.reduce_axis(0, 96), tir.reduce_axis(0, 3), tir.reduce_axis(0, 3)], "conv2d_nchwc") as [n, c0, h, w, c1, rc, rh, rw]:
                                        tir.bind(n, 0)
                                        tir.bind(c0, ((i1_1_i2_1_i3_0_1_i4_0_1_fused*6) + i1_4))
                                        tir.bind(h, (((i1_0_i2_0_i3_0_0_i4_0_0_fused*24) + (i1_2_i2_2_i3_0_2_i4_0_2_fused*2)) + i2_3))
                                        tir.bind(w, ((i3_0_4*16) + i3_1))
                                        tir.bind(c1, i4_1)
                                        tir.bind(rc, (((i5_0_0*48) + (i5_0_2*16)) + i5_1))
                                        tir.bind(rh, i6_1)
                                        tir.bind(rw, i7_0)
                                        # X is read directly from its input buffer; only W is staged through shared.
                                        tir.reads([conv2d_nchwc_local[n, c0, h, w, c1], X[n, tir.floordiv(rc, 16), (h + rh), (w + rw), tir.floormod(rc, 16)], W_shared[c0, tir.floordiv(rc, 16), rh, rw, tir.floormod(rc, 16), c1]])
                                        tir.writes([conv2d_nchwc_local[n, c0, h, w, c1]])
                                        with tir.init():
                                            conv2d_nchwc_local[n, c0, h, w, c1] = tir.float32(0)
                                        conv2d_nchwc_local[n, c0, h, w, c1] = (conv2d_nchwc_local[n, c0, h, w, c1] + (tir.cast(X[n, tir.floordiv(rc, 16), (h + rh), (w + rw), tir.floormod(rc, 16)], "float32")*tir.cast(W_shared[c0, tir.floordiv(rc, 16), rh, rw, tir.floormod(rc, 16), c1], "float32")))
                            # Write back this thread's 6 (c0) x 2 (h) x 96 (w) x 16 (c1) region to conv2d_nchwc.
                            for ax1_1, ax2_1, ax3, ax4_1 in tir.grid(6, 2, 96, 16):
                                with tir.block([1, 12, 96, 96, 16], "conv2d_nchwc_local") as [v0_1, v1_1, v2_1, v3_1, v4_1]:
                                    tir.bind(v0_1, 0)
                                    tir.bind(v1_1, ((i1_1_i2_1_i3_0_1_i4_0_1_fused*6) + ax1_1))
                                    tir.bind(v2_1, (((i1_0_i2_0_i3_0_0_i4_0_0_fused*24) + (i1_2_i2_2_i3_0_2_i4_0_2_fused*2)) + ax2_1))
                                    tir.bind(v3_1, ax3)
                                    tir.bind(v4_1, ax4_1)
                                    tir.reads([conv2d_nchwc_local[v0_1, v1_1, v2_1, v3_1, v4_1]])
                                    tir.writes([conv2d_nchwc[v0_1, v1_1, v2_1, v3_1, v4_1]])
                                    conv2d_nchwc[v0_1, v1_1, v2_1, v3_1, v4_1] = conv2d_nchwc_local[v0_1, v1_1, v2_1, v3_1, v4_1]

Regions of interest:
tir.For#0
for (i7_0, 0, 3) {
  for (ax0, 0, 12) {
    for (ax1, 0, 3) {
      for (ax2, 0, 3) {
        for (ax4, 0, 16) {
          for (ax5, 0, 16) {
            block W_shared(iter_var(v0, range(min=0, ext=12)), iter_var(v1, range(min=0, ext=6)), iter_var(v2, range(min=0, ext=3)), iter_var(v3, range(min=0, ext=3)), iter_var(v4, range(min=0, ext=16)), iter_var(v5, range(min=0, ext=16))) {
              bind(v0, ax0)
              bind(v1, ((i5_0_0*3) + ax1))
              bind(v2, ax2)
              bind(v3, i7_0)
              bind(v4, ax4)
              bind(v5, ax5)
              reads([W[v0, v1, v2, v3, v4, v5]])
              writes([W_shared[v0, v1, v2, v3, v4, v5]])
              W_shared[v0, v1, v2, v3, v4, v5] = W[v0, v1, v2, v3, v4, v5]
            }
          }
        }
      }
    }
  }
  for (i5_0_1, 0, 1) {
    for (i6_1, 0, 3) {
      for (i7_1, 0, 1) {
        for (i1_3, 0, 1) {
          for (i2_3, 0, 2) {
            for (i3_0_3, 0, 1) {
              for (i4_0_3, 0, 1) {
                for (i5_0_2, 0, 3) {
                  for (i6_2, 0, 1) {
                    for (i7_2, 0, 1) {
                      for (i1_4, 0, 6) {
                        for (i2_4, 0, 1) {
                          for (i3_0_4, 0, 6) {
                            for (i4_0_4, 0, 1) {
                              for (i3_1, 0, 16) {
                                for (i4_1, 0, 16) {
                                  for (i5_1, 0, 16) {
                                    block conv2d_nchwc(iter_var(n, range(min=0, ext=1)), iter_var(c0, range(min=0, ext=12)), iter_var(h, range(min=0, ext=96)), iter_var(w, range(min=0, ext=96)), iter_var(c1, range(min=0, ext=16)), iter_var(rc, range(min=0, ext=96)), iter_var(rh, range(min=0, ext=3)), iter_var(rw, range(min=0, ext=3))) {
                                      bind(n, 0)
                                      bind(c0, ((i1_1_i2_1_i3_0_1_i4_0_1_fused*6) + i1_4))
                                      bind(h, (((i1_0_i2_0_i3_0_0_i4_0_0_fused*24) + (i1_2_i2_2_i3_0_2_i4_0_2_fused*2)) + i2_3))
                                      bind(w, ((i3_0_4*16) + i3_1))
                                      bind(c1, i4_1)
                                      bind(rc, (((i5_0_0*48) + (i5_0_2*16)) + i5_1))
                                      bind(rh, i6_1)
                                      bind(rw, i7_0)
                                      reads([conv2d_nchwc_local[n, c0, h, w, c1], X[n, floordiv(rc, 16), (h + rh), (w + rw), floormod(rc, 16)], W_shared[c0, floordiv(rc, 16), rh, rw, floormod(rc, 16), c1]])
                                      writes([conv2d_nchwc_local[n, c0, h, w, c1]])
                                      with init() {
                                        conv2d_nchwc_local[n, c0, h, w, c1] = 0f
                                      }
                                      conv2d_nchwc_local[n, c0, h, w, c1] = (conv2d_nchwc_local[n, c0, h, w, c1] + (float32(X[n, floordiv(rc, 16), (h + rh), (w + rw), floormod(rc, 16)])*float32(W_shared[c0, floordiv(rc, 16), rh, rw, floormod(rc, 16), c1])))
                                    }
                                  }
                                }
                              }
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
tir.For#1
for (ax0, 0, 12) {
  for (ax1, 0, 3) {
    for (ax2, 0, 3) {
      for (ax4, 0, 16) {
        for (ax5, 0, 16) {
          block W_shared(iter_var(v0, range(min=0, ext=12)), iter_var(v1, range(min=0, ext=6)), iter_var(v2, range(min=0, ext=3)), iter_var(v3, range(min=0, ext=3)), iter_var(v4, range(min=0, ext=16)), iter_var(v5, range(min=0, ext=16))) {
            bind(v0, ax0)
            bind(v1, ((i5_0_0*3) + ax1))
            bind(v2, ax2)
            bind(v3, i7_0)
            bind(v4, ax4)
            bind(v5, ax5)
            reads([W[v0, v1, v2, v3, v4, v5]])
            writes([W_shared[v0, v1, v2, v3, v4, v5]])
            W_shared[v0, v1, v2, v3, v4, v5] = W[v0, v1, v2, v3, v4, v5]
          }
        }
      }
    }
  }
}

Error message: The loops can't be fused because the inner loop tir.For#1 is not the only child of outer loop tir.For#0.
junrushao commented 2 years ago

Won't fix