tlc-pack / relax

Apache License 2.0
193 stars 58 forks source link

[Pass] Operator Legalization #425

Closed: MasterJH5574 closed this pull request 1 year ago.

MasterJH5574 commented 1 year ago

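This PR introduces the operator legalization pass (`LegalizeOps`), which transforms high-level Relax operator calls into `call_tir` calls to the corresponding low-level TIR PrimFuncs. The default legalization of any operator can be overridden by passing a dict that maps the operator name (e.g. `"relax.add"`) to a customized legalization function.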

The following code shows how to use this pass:

# Define the pass input IRModule: a Relax function that adds two (2, 3)
# float32 tensors and then multiplies `y` by that sum. Both `R.add` and
# `R.multiply` are high-level operator calls that LegalizeOps lowers.
@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        # Lowered by the customized "relax.add" function (operands swapped).
        z: R.Tensor((2, 3), "float32") = R.add(x, y)
        # No customization is supplied for "relax.multiply", so this call is
        # lowered by the pass's default legalization.
        r: R.Tensor((2, 3), "float32") = R.multiply(y, z)
        return r

# Define the customized legalization function for "relax.add"
def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    """Lower a ``relax.add`` call through ``topi.add`` with its operands reversed.

    The reversal is intentional: it makes the customized binding easy to
    recognize in the legalized module, where ``R.add(x, y)`` becomes a
    ``call_tir`` on ``(y, x)``.
    """
    from tvm import topi

    lhs, rhs = call.args[0], call.args[1]
    # Emit the TE computation via the block builder, second operand first.
    return bb.call_te(topi.add, rhs, lhs)

# Apply the pass with the customized function to the module. Operators that
# have no entry in the dict (here "relax.multiply") keep the pass's default
# legalization. `LegalizeOps` must be imported; the example otherwise never
# brings it into scope.
from tvm.relax.transform import LegalizeOps

mod = LegalizeOps({"relax.add": customize_legalize_add})(Module)

#################################################################
# The resulting IRModule. The first binding in "main" was lowered by the
# customized function, so its call_tir receives the operands as "(y, x)"
# rather than "(x, y)".
# NOTE: this class re-binds the name `Module`; in a runnable script it would
# shadow the input module defined above, so give it a different name there.
@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        # Each high-level op call has been replaced by a call_tir to a
        # generated PrimFunc. The (2, 3) / dtype arguments give the shape and
        # dtype of the output buffer (the last parameter of each PrimFunc).
        z = R.call_tir(add, (y, x), (2, 3), dtype="float32")
        r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32")
        return r

    # Element-wise addition generated from topi.add by the customized
    # "relax.add" legalization: T_add = A + B.
    @T.prim_func
    def add(
        A: T.Buffer[(2, 3), "float32"],
        B: T.Buffer[(2, 3), "float32"],
        T_add: T.Buffer[(2, 3), "float32"],
    ):
        # Declares that the buffer arguments do not alias one another.
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_add"):
                # Both block axes are spatial ("SS").
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    # Element-wise multiplication produced by the default "relax.multiply"
    # legalization: T_multiply = A * B.
    @T.prim_func
    def multiply(
        A: T.Buffer[(2, 3), "float32"],
        B: T.Buffer[(2, 3), "float32"],
        T_multiply: T.Buffer[(2, 3), "float32"],
    ):
        # Declares that the buffer arguments do not alias one another.
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_multiply"):
                # Both block axes are spatial ("SS").
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1]

Co-authored-by: Chaofan Lin <siriusneo@sjtu.edu.cn>
Co-authored-by: Yixin Dong <ubospica@gmail.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>