tlc-pack / relax

Apache License 2.0

[BYOC] Add pass to merge composite functions to offload large subgraphs #372

Closed: masahi closed this 1 year ago

masahi commented 1 year ago

Part of https://github.com/tlc-pack/relax/issues/364

This PR adds a pass that merges neighboring calls to composite functions offloaded to the same external backend into a single function. This matters for backends that want to receive subgraphs as large as possible, such as TensorRT. It plays the same role as the MergeCompilerRegions pass in Relay BYOC, and the algorithm follows the idea described in https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830. As you can imagine, the problem gets tricky when branches diverge and merge.
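To illustrate why diverging and merging branches are the hard case, here is a minimal, simplified sketch of greedy region merging in plain Python rather than on Relax IR. Assumptions not taken from the PR: the graph is given as node names in topological order with predecessor lists, each node carries a backend label (or None if it stays on the host), and a node greedily joins the group of a same-backend predecessor. `merge_regions` is a hypothetical name; the actual pass works on Relax IR and follows the linked discussion more closely.

```python
from typing import Dict, List, Optional, Set


def merge_regions(
    order: List[str],
    preds: Dict[str, List[str]],
    backend: Dict[str, Optional[str]],
) -> Dict[str, int]:
    """Greedily assign each node a group (region) id.

    Hypothetical sketch: `order` must be a topological order, `preds[v]`
    lists v's inputs, and `backend[v]` is None for host-side nodes.
    """
    group: Dict[str, int] = {}
    group_backend: Dict[int, Optional[str]] = {}
    # deps[v]: ids of every group that v transitively depends on.
    deps: Dict[str, Set[int]] = {}
    next_id = 0

    for v in order:
        ps = preds.get(v, [])
        deps[v] = set()
        for p in ps:
            deps[v] |= deps[p] | {group[p]}

        chosen: Optional[int] = None
        if backend[v] is not None:
            for p in ps:
                g = group[p]
                if group_backend[g] != backend[v]:
                    continue
                # Joining g is safe only if no other input of v reaches v
                # through a path that leaves g; otherwise the merged group
                # would both feed and consume that outside node (a cycle).
                if all(group[q] == g or g not in deps[q] for q in ps):
                    chosen = g
                    break

        if chosen is None:
            # Start a new group; host nodes always end up alone.
            chosen = next_id
            group_backend[chosen] = backend[v]
            next_id += 1
        group[v] = chosen
    return group
```

For a diamond a → {b, c} → d where a, b and d are offloaded to dnnl and c stays on the host, a and b share a group, but d cannot join it: c depends on that group, so merging d would make the group both a producer and a consumer of c.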

Before

@R.function
def main(...):
    with R.dataflow():
        lv: ... = fused_relax_nn_conv2d_relax_nn_relu(data, weight1)
        gv: ... = fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2)
        R.output(gv)
    return gv

@R.function
def fused_relax_nn_conv2d_relax_nn_relu(...):
    R.func_attr({"Composite": "dnnl.conv2d_relu", ...})
    ...

@R.function
def fused_relax_nn_conv2d_relax_nn_relu1(...):
    R.func_attr({"Composite": "dnnl.conv2d_relu", ...})
    ...

After

@R.function
def main(...):
    with R.dataflow():
        gv: ... = fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(data, weight1, weight2)
        R.output(gv)
    return gv

@R.function
def fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1(data, weight1, weight2):
    R.func_attr({"Codegen": "dnnl", "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1"})
    with R.dataflow():
        @R.function
        def lv(...):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            ...

        lv2: ... = lv(data, weight1)

        @R.function
        def lv11(...):
            R.func_attr({"Composite": "dnnl.conv2d_relu"})
            ...

        gv3: ... = lv11(lv2, weight2)
        R.output(gv3)
    return gv3
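The After example suggests two conventions: the Codegen attribute ("dnnl") matches the part of the Composite name before the first ".", and the merged function's name concatenates the fused function names under a single `fused_` prefix. The sketch below reproduces that pattern for this example only; `merged_function_info` is a hypothetical helper, and the pass itself may derive these values differently.

```python
from typing import List, Tuple


def merged_function_info(fused: List[Tuple[str, str]]) -> Tuple[str, str]:
    """Hypothetical helper: derive (codegen, merged name) from the
    (function name, Composite attribute) pairs of the merged calls,
    given in binding order."""
    # Backend = Composite name up to the first "." (e.g. "dnnl.conv2d_relu").
    backends = {composite.split(".", 1)[0] for _, composite in fused}
    if len(backends) != 1:
        raise ValueError(f"cannot merge functions for different backends: {backends}")
    prefix = "fused_"
    # Drop each function's own "fused_" prefix, then join under one prefix.
    names = [name[len(prefix):] if name.startswith(prefix) else name for name, _ in fused]
    return backends.pop(), prefix + "_".join(names)
```

Applied to the two functions in the Before example, it returns `("dnnl", "fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1")`, matching the After example.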

An interesting aspect of my implementation is that the new pass reuses the same function-grouping mutator that FuseOps and FuseOpsByPattern use; the only difference among these passes is, again, how subexpressions are partitioned into groups. Since the new pass is meant to run after FuseOpsByPattern, we are essentially running one fusion pass on the output of another (i.e., fusing subgraphs, each of which is itself a fusion of ops). For now, the new pass is named MergeCompositeFunctions, and the function-grouping mutator (OperatorFusor in fuse_ops.cc) is exposed for reuse as the MakeGroupedFunctions function. I welcome suggestions for better names for both.
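To make the "same mutator, different partition" idea concrete, here is a toy sketch over string variable names rather than Relax IR. Given any partition of bindings into groups, it computes what each grouped function must take as parameters (variables bound in another group, or never bound, like `data`) and what it must return (variables consumed by another group or by the block output). `make_grouped_functions` here is hypothetical and only borrows its name from the PR; it is not the signature of the C++ MakeGroupedFunctions.

```python
from typing import Dict, List, Tuple


def make_grouped_functions(
    bindings: List[Tuple[str, List[str]]],
    group_of: Dict[str, int],
    outputs: List[str],
) -> Dict[int, Dict[str, List[str]]]:
    """Toy sketch: per group, list its members, parameters and returns.

    `bindings` holds (bound var, input vars) in program order; `outputs`
    are vars used after the block (e.g. by R.output).
    """
    defined_in = {var: group_of[var] for var, _ in bindings}
    groups: Dict[int, Dict[str, List[str]]] = {}
    for var, args in bindings:
        g = group_of[var]
        info = groups.setdefault(g, {"members": [], "params": [], "returns": []})
        info["members"].append(var)
        for a in args:
            # Inputs not bound inside this group become parameters.
            if defined_in.get(a) != g and a not in info["params"]:
                info["params"].append(a)
    for var, args in bindings:
        for a in args:
            src = defined_in.get(a)
            # Values consumed by another group must be returned by their producer.
            if src is not None and src != group_of[var] and a not in groups[src]["returns"]:
                groups[src]["returns"].append(a)
    for out in outputs:
        src = defined_in[out]
        if out not in groups[src]["returns"]:
            groups[src]["returns"].append(out)
    return groups
```

With both bindings of the Before example placed in one group, the result has parameters `data, weight1, weight2` and return value `gv`, matching the signature of the merged function in the After example. Only the `group_of` argument would differ between a FuseOps-style and a MergeCompositeFunctions-style partition.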

A bug in OperatorFusor when a tuple-producing function is involved

This was found while I was working on the complicated example from https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830. Currently, bindings in a DataflowBlockNode are processed in their original order, but this is incorrect when a group depends on another group that produces a tuple.

See the example below. Group B2 depends on group A1, which produces a tuple, so the grouped function for A1 must be emitted before the one for B2. Depending on where the binding for the node in B2 appears in the original order, the grouped function for B2 may be emitted before the one for A1, consuming a variable from A1 that becomes invalid once it is remapped to the result of a TupleGetItem.

This is fixed by processing bindings in a topological order of the group dependency graph. cc @Hzfengsy

[Screenshot 2023-01-24 8 03 56]
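The fix of emitting groups in topological order of their dependencies can be sketched with Kahn's algorithm. This is a minimal illustration on string variable names, not the code in fuse_ops.cc: the choice of Kahn's algorithm, the function name `group_emission_order`, and the tie-breaking by first appearance of the initially ready groups are all assumptions.

```python
from collections import deque
from typing import Deque, Dict, List, Set, Tuple


def group_emission_order(
    bindings: List[Tuple[str, List[str]]],
    group_of: Dict[str, int],
) -> List[int]:
    """Hypothetical sketch: order group ids so producers precede consumers."""
    first_seen: List[int] = []
    edges: Dict[int, Set[int]] = {}
    indegree: Dict[int, int] = {}
    defined_in = {var: group_of[var] for var, _ in bindings}
    for var, _ in bindings:
        g = group_of[var]
        if g not in indegree:
            first_seen.append(g)
            indegree[g] = 0
            edges[g] = set()
    # Edge src -> dst whenever a binding in dst uses a var bound in src.
    for var, args in bindings:
        dst = group_of[var]
        for a in args:
            src = defined_in.get(a)  # None for block inputs such as params
            if src is not None and src != dst and dst not in edges[src]:
                edges[src].add(dst)
                indegree[dst] += 1
    # Kahn's algorithm: repeatedly emit a group with no pending producers.
    ready: Deque[int] = deque(g for g in first_seen if indegree[g] == 0)
    order: List[int] = []
    while ready:
        g = ready.popleft()
        order.append(g)
        for h in sorted(edges[g], key=first_seen.index):
            indegree[h] -= 1
            if indegree[h] == 0:
                ready.append(h)
    if len(order) != len(first_seen):
        raise ValueError("group dependency graph has a cycle")
    return order
```

For bindings `c`, `a`, `d` where `c` and `d` form group 1, `a` forms group 0, and `d` consumes `a`, group 1 appears first in the original order, yet the sketch emits group 0 first because group 1 depends on it, which is the situation described for A1 and B2.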

cc @sunggg @psrivas2 @mbaret @gigiblender @mikepapadim