apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] [MetaScheduler] [TensorIR] Stride values are not inferred within _impl() block. #15522

Closed cbalint13 closed 11 months ago

cbalint13 commented 11 months ago

Hi folks !

I am facing the following issue when using .access_ptr() within auto-tensorization for a custom ISA: the backend simply fails to infer the .strides values of the derived ancestor buffer through which .access_ptr() views the data buffer.

The tvm-ms-testcase.py.gz script is attached here, using the main branch as of 2023-08-04 (git hash 60855346).

Description

Here are the declarations for the block description and implementation:

@T.prim_func
def vec_u8_i8_s32_desc(
    A: T.Buffer((INT8_MACS,), "uint8", offset_factor=1, align=INT8_MACS, scope="global"),
    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, align=INT8_MACS, scope="global"),
    C: T.Buffer((INT32_LANES,), "int32", offset_factor=1, align=INT8_MACS, scope="global"),
) -> None:
    with T.block("root"):
        T.reads(C[0:INT32_LANES], A[0:INT8_MACS], B[0:INT32_LANES, 0:INT8_MACS])
        T.writes(C[0:INT32_LANES])
        for i in T.serial(0, INT32_LANES):
            for k in T.serial(0, INT8_MACS):
                with T.block("update"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")

@T.prim_func
def vec_u8_i8_s32_impl(
    A: T.Buffer((INT8_MACS,), "uint8", offset_factor=1, align=INT8_MACS, scope="global"),
    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, strides=[T.int32(), T.int32()], scope="global"),
    C: T.Buffer((INT32_LANES,), "int32", offset_factor=1, align=INT32_LANES, scope="global"),
) -> None:
    with T.block("root"):
        T.block_attr({"pragma_import_llvm": VEC_MAC_impl()})
        T.reads(C[0:INT32_LANES], A[0:INT8_MACS], B[0:INT32_LANES, 0:INT8_MACS])
        T.writes(C[0:INT32_LANES])
        with T.block("update"):
            T.call_extern(
                "VEC_MACC",
                C.access_ptr("w"),
                A.access_ptr("r"),
                B.access_ptr("r"),
|--------->>>   B.strides[0],   # BUG ?! (does not work)
                # 8,            # manual workaround (cbalint)
                dtype="int32")

VEC_MACC_INTRIN = "vec_macc"

TensorIntrin.register(
    VEC_MACC_INTRIN, vec_u8_i8_s32_desc, vec_u8_i8_s32_impl
)

Actual results:

In the context of T.call_extern("VEC_MACC", {...}):

In contrast, here are two examples from the repo that apparently do perform such stride inference:
https://github.com/apache/tvm/blob/907b29e544bd8acaf38b0a9be1c36543a51cdbe6/python/tvm/tir/tensor_intrin/cuda.py#L175
https://github.com/apache/tvm/blob/907b29e544bd8acaf38b0a9be1c36543a51cdbe6/python/tvm/tir/tensor_intrin/rocm.py#L240

Desired results:

Properly infer the stride values in such cases.

Thank you ! ~Cristian.


cbalint13 commented 11 months ago

Tracked down the issue.

  1. The LLVM lowered code, before this->Optimize(), shows that the TIR-inferred stride is i64:

    %6 = call i32 @VEC_MACC(ptr %4, ptr %5, ptr @fused_constant, i64 16)
    %9 = call i32 @VEC_MACC(ptr %7, ptr %8, ptr @fused_constant, i64 16)
  2. But after this->Optimize(), the dso_local VEC_MACC declaration is marked undefined due to the i32 vs i64 mismatch:

    %2 = tail call i32 @VEC_MACC(ptr %0, ptr %1, ptr nonnull @fused_constant, i64 16)
    declare dso_local i32 @VEC_MACC(ptr noundef, ptr noundef, ptr noundef, i32 noundef) local_unnamed_addr #4
  3. Explicit casts such as T.int32() or T.int32(B.strides[0]) are ineffective; the lowered LLVM type remains i64.

  4. In contrast to MS-TIR, tensorizing the TOPI way with strides=[te.var("ldw"), 1] lowers to i32.

  5. Anyway, changing the VEC_MACC declaration to use int64_t solves the whole problem:

    
    --- tvm-ms-testcase.py.old  2023-08-11 12:29:56.055634954 +0300
    +++ tvm-ms-testcase.py  2023-08-11 12:30:13.215511711 +0300
    @@ -37,7 +37,7 @@
    int32_t VEC_MACC(int32_t *output,
                   const uint8_t *data,
                   const int8_t *kernel,
    -                  const int32_t stride) {{
    +                  const int64_t stride) {{
    printf("data: \\n");
    for (int j = 0; j < {INT8_MACS}; ++j) {{
       printf(" %i", data[j]);
    @@ -92,7 +92,7 @@
    @T.prim_func
    def vec_u8_i8_s32_impl(
     A: T.Buffer((INT8_MACS,), "uint8", offset_factor=1, align=INT8_MACS, scope="global"),
    -    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, strides=[T.int32(), T.int32()], align=INT8_MACS, scope="global"),
    +    B: T.Buffer((INT32_LANES, INT8_MACS), "int8", offset_factor=1, strides=[T.int64(), T.int64()], align=INT8_MACS, scope="global"),
     C: T.Buffer((INT32_LANES,), "int32", offset_factor=1, align=INT32_LANES, scope="global"),
    ) -> None:
     with T.block("root"):


Issue done.

---

But there are still some points of confusion:

* The most confusing part (still) is that explicit `T.int32()` casts are not honored in any way.
* Also, in this specific application, the C function invocation requires an **exact type match**.
* Despite enabling all logs, it was impossible [to see](https://github.com/apache/tvm/blob/624f8a73c7a7ded99bc6c00e59c468de6b9315e1/python/tvm/meta_schedule/runner/local_runner.py#L281) the "O_VEC_MACC undefined" error in local_runner.