This PR introduces the operator legalization pass, which transforms high-level operator calls into call_tir invocations of the corresponding low-level TIR PrimFuncs.
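The following toy, pure-Python sketch illustrates the shape of this rewrite. The Binding class and legalize function are hypothetical and do not model TVM's Relax IR, and no real TIR is generated. The sketch only shows that each high-level operator binding becomes a call_tir binding that refers to a named PrimFunc, which is recorded alongside the rewritten function.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Binding:
    """Toy binding `var = op(args)`; hypothetical, not TVM's Relax IR."""

    var: str
    op: str  # e.g. "relax.add", or "call_tir" after legalization
    args: Tuple[str, ...]
    callee: str = ""  # name of the PrimFunc when op == "call_tir"


def legalize(bindings: List[Binding]) -> Tuple[List[Binding], Dict[str, str]]:
    """Rewrite each high-level op binding into a call_tir of a named PrimFunc.

    Returns the rewritten bindings and a table mapping each PrimFunc name to
    the high-level op it implements (a stand-in for the generated TIR).
    """
    prim_funcs: Dict[str, str] = {}
    rewritten: List[Binding] = []
    for b in bindings:
        if b.op.startswith("relax."):
            name = b.op.split(".", 1)[1]  # "relax.add" -> "add"
            prim_funcs.setdefault(name, b.op)  # toy: one PrimFunc per op name
            rewritten.append(Binding(b.var, "call_tir", b.args, callee=name))
        else:
            rewritten.append(b)  # leave anything else untouched
    return rewritten, prim_funcs


main = [Binding("z", "relax.add", ("x", "y")), Binding("r", "relax.multiply", ("y", "z"))]
new_main, funcs = legalize(main)
for b in new_main:
    print(f"{b.var} = {b.op}({b.callee}, {b.args})")
# z = call_tir(add, ('x', 'y'))
# r = call_tir(multiply, ('y', 'z'))
```

In the actual pass the PrimFuncs are real TIR functions added to the IRModule; here a plain dictionary stands in for them.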
The pass is customizable: users can pass in a customized legalization map, from operator name to legalization function, to override the default legalization of the corresponding operators.
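The override rule can be sketched in plain Python as follows. This is a hypothetical model, not the pass's code: calls are (op_name, args) tuples, DEFAULT_MAP and pick_legalize_func are illustrative names, and the only point is the precedence rule, where an entry in the customized map wins over the default and every other operator falls back to its default legalization.

```python
from typing import Callable, Dict, Optional, Tuple

# Toy stand-ins (hypothetical, not TVM types): a call is (op_name, args) and a
# legalized call is ("call_tir", prim_func_name, args).
Call = Tuple[str, Tuple[str, ...]]
Lowered = Tuple[str, str, Tuple[str, ...]]
LegalizeFunc = Callable[[Call], Lowered]


def default_legalize_add(call: Call) -> Lowered:
    return ("call_tir", "add", call[1])


def default_legalize_multiply(call: Call) -> Lowered:
    return ("call_tir", "multiply", call[1])


DEFAULT_MAP: Dict[str, LegalizeFunc] = {
    "relax.add": default_legalize_add,
    "relax.multiply": default_legalize_multiply,
}


def pick_legalize_func(
    op_name: str, custom_map: Optional[Dict[str, LegalizeFunc]] = None
) -> Optional[LegalizeFunc]:
    """Return the legalization function for op_name.

    A customized entry wins over the default; None means this toy table has
    no entry for op_name.
    """
    if custom_map is not None and op_name in custom_map:
        return custom_map[op_name]
    return DEFAULT_MAP.get(op_name)


def customize_legalize_add(call: Call) -> Lowered:
    # Mirrors the customization in the example: swap the two operands.
    x, y = call[1]
    return ("call_tir", "add", (y, x))


custom = {"relax.add": customize_legalize_add}
print(pick_legalize_func("relax.add", custom)(("relax.add", ("x", "y"))))
# ('call_tir', 'add', ('y', 'x'))
```

Keeping the default table separate from the user's map means a customization affects only the operators it names.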
The legalization supports symbolic shapes. (At the moment, only pooling does not support symbolic shapes, because TOPI pooling does not support them. This will be fixed in follow-up PRs.)
To speed up development, as a first step the pass is implemented on the Python side, which is sufficient for now. Eventually, we will move the pass to the C++ side, with the legalization functions registered per op in the operator registry.
The following code shows how to use this pass:
```python
import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R


# Define the pass input IRModule
@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        z: R.Tensor((2, 3), "float32") = R.add(x, y)
        r: R.Tensor((2, 3), "float32") = R.multiply(y, z)
        return r


# Define the customized legalization function for "relax.add"
def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    from tvm import topi

    # Swap the operands, so that add(x, y) is legalized as add(y, x)
    return bb.call_te(topi.add, call.args[1], call.args[0])


# Apply the pass with the customized function to the module.
mod = LegalizeOps({"relax.add": customize_legalize_add})(Module)
```
The result IRModule is as follows (note that the arguments of the first binding in "main" are swapped to "(y, x)" by the customized function):

```python
@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        z = R.call_tir(add, (y, x), (2, 3), dtype="float32")
        r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32")
        return r

    @T.prim_func
    def add(
        A: T.Buffer[(2, 3), "float32"],
        B: T.Buffer[(2, 3), "float32"],
        T_add: T.Buffer[(2, 3), "float32"],
    ):
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func
    def multiply(
        A: T.Buffer[(2, 3), "float32"],
        B: T.Buffer[(2, 3), "float32"],
        T_multiply: T.Buffer[(2, 3), "float32"],
    ):
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1]
```
Co-authored-by: Chaofan Lin <siriusneo@sjtu.edu.cn>
Co-authored-by: Yixin Dong <ubospica@gmail.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>