apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Tensorization breaks when TIR one dimension is a unit iterator #16566

Open patschmidt2 opened 6 months ago

patschmidt2 commented 6 months ago


Expected behavior

Tensorization works, replacing a part of the schedule with my custom instruction

Actual behavior

Tensorization throws an error about a CompareBufferRegion buffer region min mismatch. I think that during tensorization the schedule is simplified to prune inner unit iters, see this function here. The inner iteration variable is set to zero. However, if I have an intrinsic where one dimension is a unit iterator, the same simplification is not applied to the intrinsic description, so the two sides are no longer structurally equivalent and the match fails. I am aware that intrinsics with unit iters are not the most common use case, but the equivalent feature existed in TE-based scheduling, so it would be nice if it still worked with TIR.

Environment

Rocky Linux

Steps to reproduce

import tvm
from tvm import te
from tvm.script import tir as T

dim_I = 1  # the unit dimension: dim_I == 1 makes the i axis a unit iterator
dim_K = 1024
dim_J = 512

inp_shape =  (dim_I, dim_K)
wght_shape = (dim_K, dim_J)
out_shape =  (dim_I, dim_J)

ins_dtype = "int8"
out_dtype = "int8"

inp = te.placeholder(inp_shape, dtype=ins_dtype, name="a_in")
wght = te.placeholder(wght_shape, dtype=ins_dtype, name="b_in")
rk = te.reduce_axis((0, dim_K), name="k")

res = te.compute(
    out_shape,
    lambda i, j: te.sum(
        inp[i, rk].astype(out_dtype) * wght[rk, j].astype(out_dtype),
        axis=[rk],
    ),
    name="res",
    tag="dense",
)

func = te.create_prim_func([inp, wght, res])
sch = tvm.tir.Schedule(func)

def get_intrin_gemm(
    dim_i: int,
    dim_k: int,
    dim_j: int,
):
    @T.prim_func
    def matmul_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (dim_i, dim_k), "int8", offset_factor=1)
        B = T.match_buffer(b, (dim_k, dim_j), "int8", offset_factor=1)
        C = T.match_buffer(c, (dim_i, dim_j), "int8", offset_factor=1)

        with T.block("root"):
            T.reads(C[0:dim_i, 0:dim_j], A[0:dim_i, 0:dim_k], B[0:dim_k, 0:dim_j])
            T.writes(C[0:dim_i, 0:dim_j])
            for i, k, j in T.grid(dim_i, dim_k, dim_j):
                with T.block(""):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], ins_dtype) * T.cast(B[vkk, vjj], ins_dtype)

    @T.prim_func
    def matmul_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (dim_i, dim_k), "int8", offset_factor=1)
        B = T.match_buffer(b, (dim_k, dim_j), "int8", offset_factor=1)
        C = T.match_buffer(c, (dim_i, dim_j), "int8", offset_factor=1)

        with T.block("root"):
            T.reads(A[0:dim_i, 0:dim_k], B[0:dim_k, 0:dim_j], C[0:dim_i, 0:dim_j],)
            T.writes(C[0:dim_i, 0:dim_j])
            T.evaluate(T.call_extern("computer_function_extern", dtype=""))
    return matmul_desc, matmul_impl

desc, impl = get_intrin_gemm(dim_I, dim_K, dim_J)

res_block = sch.get_block("res")
i, j, k = sch.get_loops(res_block)
sch.reorder(i, k, j)
sch.decompose_reduction(res_block, i)

tvm.tir.TensorIntrin.register("matmul_intrin", desc, impl)
sch.tensorize(i, "matmul_intrin")


LeiWang1999 commented 6 months ago

Looks like it's related to #16560: the comparator simplifies the lhs from [v0, v1] -> [0, v1], but the tensor desc still keeps [v0, v1]. Check out this PR.

patschmidt2 commented 6 months ago

Thanks, that definitely goes in the right direction! I have ported this over to my local branch. I still get an error though:

Error message: The stmt tir.For#0 doesn't match the tensor intrin
The pattern attempting to be matched:
for j in range(512):
    k = T.int32()
    with T.block("res_update"):
        v_j_i = T.axis.spatial(512, j)
        v_k_i = T.axis.reduce(1024, k)
        res = T.Buffer((1, 512), "int8")
        v_i_o = T.int32()
        a_in = T.Buffer((1, 1024), "int8")
        b_in = T.Buffer((1024, 512), "int8")
        T.reads(res[v_i_o, v_j_i], a_in[v_i_o, v_k_i], b_in[v_k_i, v_j_i])
        T.writes(res[v_i_o, v_j_i])
        res[v_i_o, v_j_i] = res[v_i_o, v_j_i] + a_in[v_i_o, v_k_i] * b_in[v_k_i, v_j_i]
Does not match the tensorize description:
for j in range(512):
    k = T.int32()
    with T.block(""):
        vii = T.axis.spatial(1, 0)
        vjj = T.axis.spatial(512, j)
        vkk = T.axis.reduce(1024, k)
        C = T.Buffer((1, 512), "int8", offset_factor=1)
        A = T.Buffer((1, 1024), "int8", offset_factor=1)
        B = T.Buffer((1024, 512), "int8", offset_factor=1)
        T.reads(C[0, vjj], A[0, vkk], B[vkk, vjj])
        T.writes(C[0, vjj])
        C[0, vjj] = C[0, vjj] + A[0, vkk] * B[vkk, vjj]
CompareArray array size mismatch. lhs.size()=2 vs rhs.size()=3
BlockRealizeNode iter_values do not match: op->iter_values=[j, k] vs rhs->iter_values=[0, j, k]

It seems like the unit iterator is not actually deleted, but is still preserved.

lhutton1 commented 6 months ago

~I believe tensorize in TIR preserves unit iterators by default. It's possible to disable this functionality using sch.tensorize(..., preserve_unit_iters=False). Perhaps that could help?~ I just tried with the reproducer and it didn't have any effect.
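For reference, the suggestion corresponds to the following variant of the tensorize call on the reproducer's loop handle (i and the registered intrin name come from the script above):

# Same tensorize call as in the reproducer, but without preserving unit iterators.
sch.tensorize(i, "matmul_intrin", preserve_unit_iters=False)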

LeiWang1999 commented 6 months ago

or maybe we can simplify the tensor desc as well.

patschmidt2 commented 6 months ago

It is not solved by setting preserve_unit_iters=False. The problem is not with the schedule; the simplifier just isn't simplifying the intrinsic description aggressively enough. If it completely deleted the unit IterVar, it should be fine.

LeiWang1999 commented 6 months ago

Hi @patschmidt2, would you mind checking out https://github.com/apache/tvm/pull/16560? Hope this addresses your problem.

patschmidt2 commented 6 months ago

@LeiWang1999 That is exactly what I tried, ported to my local version of TVM. It does go in the right direction; I just need a way to tell the Simplifier to completely remove the unit IterVar. Currently, this is the intrinsic description after the simplify step:

with T.block("root"):
    C = T.Buffer((1, 512), "int8", offset_factor=1)
    A = T.Buffer((1, 1024), "int8", offset_factor=1)
    B = T.Buffer((1024, 512), "int8", offset_factor=1)
    T.reads(C[0, 0:512], A[0, 0:1024], B[0:1024, 0:512])
    T.writes(C[0, 0:512])
    for i, k, j in T.grid(1, 1024, 512):
        with T.block(""):
            vii = T.axis.spatial(1, 0)
            vjj, vkk = T.axis.remap("SR", [j, k])
            T.reads(C[0, vjj], A[0, vkk], B[vkk, vjj])
            T.writes(C[0, vjj])
            C[0, vjj] = C[0, vjj] + A[0, vkk] * B[vkk, vjj]

The variable vii is still in there, although it is not used, and the IRComparator still complains that the two values are not equal. If you know of a way to tell the Simplify step to remove vii, I think that would do the trick.

LeiWang1999 commented 6 months ago

It's an interesting case, and I guess it's not simple to remove the unit iter axis automatically. Maybe one quick solution is to register a new tensor intrin specifically for the GEMV case (with just two dims), since TVM tensorize currently does not support dynamic symbolic shapes either, even if we fix this issue.

patschmidt2 commented 6 months ago

The reason I want an automated solution is that my hardware also supports intrinsics where one dimension can be a unit iterator, and then it gets messy quite fast: if any dimension can be a unit iterator, I would have to write separate intrinsics for every possible combination (sketched below). Not all of these combinations are a good idea, but they can show up during tuning, since I don't think there is a way to tell the sampling instructions not to produce unit iterators. I'm wondering if it would be a good idea to follow the approach used in the blockize function: it analyzes the IterVars, splits them into inner and outer iters, and constructs a new block in which the unit IterVars are simply not included again. This approach would impose one specific structure on defined intrinsics, though, and I don't understand the C++ API that well at this point.
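For illustration, the manual workaround would amount to something like the following, reusing get_intrin_gemm from the reproducer; the list of shape combinations and the naming scheme are made up:

# One registration per dimension combination that tuning might produce; this is the
# combinatorial bookkeeping described above as getting messy quite fast.
for di, dk, dj in [(1, 1024, 512), (4, 1024, 512), (1, 256, 512)]:
    desc, impl = get_intrin_gemm(di, dk, dj)
    tvm.tir.TensorIntrin.register(f"matmul_intrin_{di}x{dk}x{dj}", desc, impl)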

patschmidt2 commented 6 months ago

I've noticed one more issue. I tried to just comment out the comparisons on the iter_vars of the BlockNode and BlockRealizeNode to see if tensorization then works. This leads to the following error:

CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]=range(min=v_j_i, ext=1)Range(0x55bfb37cd9d0) vs rhs->region[i]=range(min=vjj, ext=1)Range(0x55bfb37cada0)
BlockNode write buffers do not match: op->writes=[res[0, v_j_i]] vs rhs->writes=[C[0, vjj]]

I don't understand why there is a Range() in there; that doesn't seem right to me. I have also tried to print the regions that are compared, but that also throws an error:

IndexError: Variable is not defined in the environment: v_j_i
IndexError: Variable is not defined in the environment: vjj

That is weird, since these variables should exist somewhere, and vjj is not a unit iterator either.

LeiWang1999 commented 6 months ago


BTW, should we keep vii after simplification even though it is not used in this block? @Hzfengsy

Hzfengsy commented 6 months ago

We only simplify the Exprs but keep the Stmts. The behavior for unit loops is similar: we remove unit loops in a lowering pass rather than during simplification.
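A minimal sketch of the distinction, assuming the standard tvm.arith.Analyzer API: expression-level simplification folds indexing arithmetic once a constant is substituted, but it does not delete the enclosing loop or block statements, which is why an unused block var like vii can survive:

import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
vii = tir.Var("vii", "int32")

# Expr level: arithmetic on the expression itself is folded.
print(ana.simplify(vii * 512 + 3))                     # stays symbolic
print(ana.simplify(tir.IntImm("int32", 0) * 512 + 3))  # folds to 3

# Stmt level: a `for i in T.grid(1, ...)` loop and the block var vii are statements /
# block structure; per the comment above, those are only dropped later, in a lowering pass.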

patschmidt2 commented 6 months ago

Out of curiosity: what is the difference between Stmt and Expr here?

patschmidt2 commented 4 months ago

@LeiWang1999 I tried using your commit but I still receive an error:

Error message: The stmt tir.For#0 doesn't match the tensor intrin
The pattern attempting to be matched:
for y_o in range(T.int64(32)):
    k_o = T.int64()
    with T.block("res_update"):
        v_y_o_i = T.axis.spatial(T.int64(32), y_o)
        v_k_o_i = T.axis.reduce(T.int64(64), k_o)
        res = T.Buffer((T.int64(1), T.int64(32)), "int8")
        v_x_o_o = T.int64()
        p0 = T.Buffer((T.int64(1), T.int64(64)), "int8")
        fused_constant = T.Buffer((64, 32), "int8")
        fused_constant_1 = T.Buffer((32,), "int32")
        T.reads(res[v_x_o_o, v_y_o_i], p0[v_x_o_o, v_k_o_i], fused_constant[v_k_o_i, v_y_o_i], fused_constant_1[v_y_o_i])
        T.writes(res[v_x_o_o, v_y_o_i])
        T.block_attr({"scale": T.float32(0.0007562367245554924)})
        res[v_x_o_o, v_y_o_i] = res[v_x_o_o, v_y_o_i] + (p0[v_x_o_o, v_k_o_i] * fused_constant[v_k_o_i, v_y_o_i] + T.Cast("int8", fused_constant_1[v_y_o_i]))
Does not match the tensorize description:
for j in range(T.int64(32)):
    k = T.int64()
    with T.block(""):
        vii = T.axis.spatial(T.int64(1), T.int64(0))
        vjj = T.axis.spatial(T.int64(32), j)
        vkk = T.axis.reduce(T.int64(64), k)
        C = T.Buffer((1, T.int64(32)), "int8", offset_factor=1)
        A = T.Buffer((1, T.int64(64)), "int8", offset_factor=1)
        B = T.Buffer((T.int64(64), T.int64(32)), "int8", offset_factor=1)
        Bias = T.Buffer((T.int64(32),), "int32", offset_factor=1)
        T.reads(C[T.int64(0), vjj], A[T.int64(0), vkk], B[vkk, vjj], Bias[vjj])
        T.writes(C[T.int64(0), vjj])
        C[T.int64(0), vjj] = C[T.int64(0), vjj] + (A[T.int64(0), vkk] * B[vkk, vjj] + T.Cast("int8", Bias[vjj]))
CompareArray array size mismatch. lhs.size()=2 vs rhs.size()=3
BlockRealizeNode iter_values do not match: op->iter_values=[y_o, k_o] vs rhs->iter_values=[T.int64(0), j, k]

So that is still the original error. I have found a way to change the DeriveBlockBinding function so that all iter_values are kept, but that leads to this new error:

Error message: The stmt tir.For#0 doesn't match the tensor intrin
The pattern attempting to be matched:
for y_o in range(T.int64(32)):
    x_o = T.int64()
    k_o = T.int64()
    with T.block("res_update"):
        v_x_o_i = T.axis.spatial(T.int64(1), x_o)
        v_y_o_i = T.axis.spatial(T.int64(32), y_o)
        v_k_o_i = T.axis.reduce(T.int64(64), k_o)
        res = T.Buffer((T.int64(1), T.int64(32)), "int8")
        p0 = T.Buffer((T.int64(1), T.int64(64)), "int8")
        fused_constant = T.Buffer((64, 32), "int8")
        fused_constant_1 = T.Buffer((32,), "int32")
        T.reads(res[T.int64(0), v_y_o_i], p0[T.int64(0), v_k_o_i], fused_constant[v_k_o_i, v_y_o_i], fused_constant_1[v_y_o_i])
        T.writes(res[T.int64(0), v_y_o_i])
        T.block_attr({"scale": T.float32(0.00083669449668377638)})
        res[T.int64(0), v_y_o_i] = res[T.int64(0), v_y_o_i] + (p0[T.int64(0), v_k_o_i] * fused_constant[v_k_o_i, v_y_o_i] + T.Cast("int8", fused_constant_1[v_y_o_i]))
Does not match the tensorize description:
for j in range(T.int64(32)):
    k = T.int64()
    with T.block(""):
        vii = T.axis.spatial(T.int64(1), T.int64(0))
        vjj = T.axis.spatial(T.int64(32), j)
        vkk = T.axis.reduce(T.int64(64), k)
        C = T.Buffer((1, T.int64(32)), "int8", offset_factor=1)
        A = T.Buffer((1, T.int64(64)), "int8", offset_factor=1)
        B = T.Buffer((T.int64(64), T.int64(32)), "int8", offset_factor=1)
        Bias = T.Buffer((T.int64(32),), "int32", offset_factor=1)
        T.reads(C[T.int64(0), vjj], A[T.int64(0), vkk], B[vkk, vjj], Bias[vjj])
        T.writes(C[T.int64(0), vjj])
        C[T.int64(0), vjj] = C[T.int64(0), vjj] + (A[T.int64(0), vkk] * B[vkk, vjj] + T.Cast("int8", Bias[vjj]))
Expression mismatch: x_o vs T.int64(0)
BlockRealizeNode iter_values do not match: op->iter_values=[x_o, y_o, k_o] vs rhs->iter_values=[T.int64(0), j, k]

The variable is preserved, but the Simplifier applied to the intrin description replaces the unit loop with a constant. Is there a similar simplifier that I could apply to the Block? That seems to be the only remaining difference.

patschmidt2 commented 4 months ago

@LeiWang1999 I've tried using the same Simplify function that was applied to intrin->desc on block_realize, but that doesn't work, since a Block is a Stmt and not a PrimFunc. Is there a similar function I can apply to the block?

patschmidt2 commented 4 months ago

@LeiWang1999 Actually, I don't think simplifying the whole block is needed. Just changing op->iter_values in this function should be sufficient; a rough sketch of the idea is below.
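A rough Python-side sketch of that idea (the helper name is hypothetical and the real change would live in the C++ comparator; this only illustrates the intended filtering): drop iter values that bind a unit-extent block var to a constant, so that [T.int64(0), j, k] can be compared against [j, k]:

import tvm

def drop_unit_iter_bindings(block_realize: tvm.tir.BlockRealize):
    # Keep only (iter_var, value) pairs that are not a unit-extent var bound to a constant.
    ana = tvm.arith.Analyzer()
    kept_vars, kept_values = [], []
    for iter_var, value in zip(block_realize.block.iter_vars, block_realize.iter_values):
        unit_extent = ana.can_prove(iter_var.dom.extent == 1)
        constant_binding = isinstance(ana.simplify(value), tvm.tir.IntImm)
        if unit_extent and constant_binding:
            continue  # e.g. vii bound to T.int64(0)
        kept_vars.append(iter_var)
        kept_values.append(value)
    return kept_vars, kept_values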

LeiWang1999 commented 4 months ago

Yeah, I think modifying the iter_values should help for this issue.

patschmidt2 commented 4 months ago

@LeiWang1999 Are you aware of anything to simplify the iter_values? I wasn't able to find anything in the documentation.