Prior to this PR, CallTIR always wrapped the input `args` in a Tuple whenever `args` was an `Expr`. However, `args` can already be a Tuple. In that case the previous behavior wrapped the given Tuple in another Tuple, which is not intended.

One such case arises when CallTIR calls into a PrimFunc that takes no input tensors and only has an output tensor. The `args` of CallTIR is then an empty Tuple, which the printer prints explicitly as `R.tuple()`. When the parser parses the printed script back, CallTIR wraps that empty Tuple in another Tuple, and the resulting IR fails in the subsequent build. A demo script is shown below:
```python
import tvm
from tvm.script import relax as R
from tvm.script import tir as T


@tvm.script.ir_module
class Module:
    @T.prim_func
    def full(T_full: T.Buffer[(T.int64(16), T.int64(32)), "float32"]):
        T.func_attr({"tir.noalias": True})
        for i0, i1 in T.grid(T.int64(16), T.int64(32)):
            with T.block("T_full"):
                ax0, ax1 = T.axis.remap("SS", [i0, i1])
                T.reads()
                T.writes(T_full[ax0, ax1])
                T_full[ax0, ax1] = T.float32(1)

    @R.function
    def foo(dummy_param: R.Tensor(())) -> R.Tensor((16, 32), dtype="float32"):
        # <== `R.tuple()` is already a Tuple and must not be wrapped in another Tuple.
        gv = R.call_tir(full, R.tuple(), (16, 32), dtype="float32")
        return gv
```
This PR fixes this bug in CallTIR on the Python side.
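Conceptually, the fix is to wrap `args` only when it is not already a Tuple. The sketch below illustrates that rule in plain Python. `Expr`, `Var`, `Tuple` and `normalize_call_tir_args` are hypothetical stand-ins, not the actual classes or code in `tvm.relax`. Because a Tuple is itself an `Expr`, the Tuple check has to come before the generic `Expr` check.

```python
from dataclasses import dataclass, field
from typing import List, Sequence, Union


# Hypothetical stand-ins for Relax IR nodes; NOT the real TVM classes.
class Expr:
    pass


@dataclass
class Var(Expr):
    name: str


@dataclass
class Tuple(Expr):
    fields: List[Expr] = field(default_factory=list)


def normalize_call_tir_args(args: Union[Expr, Sequence[Expr]]) -> Tuple:
    """Hypothetical helper: return `args` as a single Tuple, wrapping only when needed."""
    if isinstance(args, Tuple):
        # Already a Tuple (e.g. the empty `R.tuple()`): keep it as-is.
        return args
    if isinstance(args, Expr):
        # A single non-Tuple expression becomes a one-element Tuple.
        return Tuple([args])
    # A Python list/tuple of expressions becomes a Tuple of those fields.
    return Tuple(list(args))
```

With this rule, an empty Tuple round-trips through print and parse unchanged instead of becoming a Tuple that contains an empty Tuple.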
cc @Hzfengsy