tlc-pack / relax

Apache License 2.0

[Fix][Op] Fix CallTIR about wrapping args with Tuple #354

Closed MasterJH5574 closed 1 year ago

MasterJH5574 commented 1 year ago

This PR fixes a bug in CallTIR on the Python side.

Prior to this PR, CallTIR always wrapped the input args in a Tuple whenever args was an Expr. However, there are cases where args is already a Tuple. In such cases, the previous behavior wrapped the given Tuple in another Tuple, which is undesired.
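The difference between the two behaviors can be sketched in plain Python. This is only an illustration, not the actual patch: `Expr`, `Var`, `Tuple`, `wrap_args_old`, and `wrap_args_fixed` are hypothetical stand-ins defined here so the snippet runs without TVM, and the real classes and call sites in the Relax code base may look different.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Expr:
    """Hypothetical stand-in for a Relax expression (not the real class)."""


@dataclass
class Var(Expr):
    """Hypothetical stand-in for a single non-Tuple argument."""
    name: str = ""


@dataclass
class Tuple(Expr):
    """Hypothetical stand-in for a Relax Tuple holding a list of fields."""
    fields: List[Expr] = field(default_factory=list)


def wrap_args_old(args: Expr) -> Tuple:
    # Behavior described above as the bug: every Expr is wrapped,
    # even when it is already a Tuple.
    return Tuple([args])


def wrap_args_fixed(args: Expr) -> Tuple:
    # Assumed shape of the fix: an existing Tuple passes through unchanged,
    # and only a non-Tuple Expr is wrapped.
    if isinstance(args, Tuple):
        return args
    return Tuple([args])


empty = Tuple([])
x = Var("x")

# The old behavior nests the empty Tuple inside another Tuple.
assert wrap_args_old(empty) == Tuple([Tuple([])])
# The fixed behavior keeps the empty Tuple as is and still wraps a lone Expr.
assert wrap_args_fixed(empty) == Tuple([])
assert wrap_args_fixed(x) == Tuple([x])
```

With the pass-through check in place, an argument list that is already a Tuple, including the empty one, keeps its shape when it is handed to CallTIR a second time.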

One example is when CallTIR calls a PrimFunc that takes no input tensors and has only an output tensor. The args for CallTIR is then an empty Tuple, which the printer prints explicitly as R.tuple(). If the parser parses the printed script back and CallTIR wraps the empty Tuple again, the resulting IR is wrong and causes an error in the subsequent build. The demo script is below:

import tvm
from tvm.script import relax as R, tir as T


@tvm.script.ir_module
class Module:
    @T.prim_func
    def full(T_full: T.Buffer[(T.int64(16), T.int64(32)), "float32"]):
        T.func_attr({"tir.noalias": True})
        for i0, i1 in T.grid(T.int64(16), T.int64(32)):
            with T.block("T_full"):
                ax0, ax1 = T.axis.remap("SS", [i0, i1])
                T.reads()
                T.writes(T_full[ax0, ax1])
                T_full[ax0, ax1] = T.float32(1)

    @R.function
    def foo(dummy_param: R.Tensor(())) -> R.Tensor((16, 32), dtype="float32"):
        gv = R.call_tir(full, R.tuple(), (16, 32), dtype="float32")  ## <==== Cannot wrap `R.tuple()` by another Tuple.
        return gv

cc @Hzfengsy