This PR adds a pass that merges neighboring calls to composite functions offloaded to the same external backend into a single function. This matters for backends that want to receive as large a subgraph as possible, for example TensorRT. It plays the same role as the `MergeCompilerRegions` pass in Relay BYOC, and the algorithm follows the idea described in https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830. As you can imagine, the problem becomes tricky when branches diverge and merge, because merging two regions that are connected through a node outside both of them can create a cyclic dependency between the merged functions.
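To illustrate why diverging and merging branches make this tricky, here is a deliberately naive Python sketch. It is not the algorithm used by this PR or by `MergeCompilerRegions`. Nodes stand for composite calls tagged with a hypothetical backend name (`None` means the node stays on the host). A node is merged into a same-backend parent's group only if the resulting group-level graph stays acyclic. The names `merge_regions`, `condensed_graph`, and `is_acyclic` are made up for this example.

```python
from graphlib import CycleError, TopologicalSorter


def condensed_graph(edges, group):
    """Group-level dependency graph: group -> set of predecessor groups."""
    graph = {g: set() for g in group.values()}
    for src, dst in edges:
        if group[src] != group[dst]:
            graph[group[dst]].add(group[src])
    return graph


def is_acyclic(graph):
    try:
        # static_order() raises CycleError once it is consumed on a cyclic graph.
        tuple(TopologicalSorter(graph).static_order())
        return True
    except CycleError:
        return False


def merge_regions(nodes, edges, backend):
    """Greedily merge each node into a same-backend parent's group (toy model).

    nodes: node names in topological order.
    edges: (producer, consumer) pairs.
    backend: node -> backend name, or None for host nodes.
    """
    group = {n: n for n in nodes}  # every node starts in its own group
    parents = {n: [] for n in nodes}
    for src, dst in edges:
        parents[dst].append(src)
    for n in nodes:
        if backend[n] is None:
            continue
        for p in parents[n]:
            if backend[p] != backend[n] or group[p] == group[n]:
                continue
            # Tentatively move n's whole group into p's group...
            old, new = group[n], group[p]
            trial = {k: (new if g == old else g) for k, g in group.items()}
            # ...and keep the merge only if no cycle appears between groups.
            if is_acyclic(condensed_graph(edges, trial)):
                group = trial
    return group


# a -> b -> c plus a -> c, with b on the host: merging a and c would form a cycle.
print(merge_regions(["a", "b", "c"], [("a", "b"), ("b", "c"), ("a", "c")],
                    {"a": "trt", "b": None, "c": "trt"}))
# {'a': 'a', 'b': 'b', 'c': 'c'}

# Diamond where every node goes to the same backend: everything merges.
print(merge_regions(["a", "b", "c", "d"],
                    [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")],
                    {n: "trt" for n in "abcd"}))
# {'a': 'a', 'b': 'a', 'c': 'a', 'd': 'a'}
```

In the first call, a merged `{a, c}` function would both feed and consume `b`, so the sketch keeps them apart. Re-checking acyclicity with a full traversal on every merge attempt is simple but slow; a real implementation would track reachability incrementally.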
An interesting thing about my implementation is that the new pass reuses the same function-grouping mutator that `FuseOps` and `FuseOpsByPattern` use; the only difference between these passes is, again, how subexpressions are partitioned into groups. Since the new pass is meant to run after `FuseOpsByPattern`, we are essentially running one fusion pass on the output of another (i.e., fusing subgraphs, each of which is itself a fusion of ops). For now, the new pass is named `MergeCompositeFunctions`, and the function-grouping mutator (`OperatorFusor` in `fuse_ops.cc`) is made reusable from outside as the `MakeGroupedFunctions` function, but I welcome suggestions for better names.
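To make the "same grouper, different partitioning" structure concrete, here is a toy Python model. `make_grouped_functions` is a hypothetical stand-in for `MakeGroupedFunctions` and does not reflect its real signature. It takes a list of bindings plus a partition (a map from each bound variable to a group label). For each group it computes the parameters (values used inside but defined outside) and the outputs (values defined inside and needed elsewhere or returned).

```python
def make_grouped_functions(bindings, partition, program_outputs=()):
    """Toy grouper (hypothetical; not the real MakeGroupedFunctions API).

    bindings: list of (var, op, args) in program order.
    partition: maps every bound var to a group label.
    program_outputs: vars returned by the enclosing function.
    """
    members = {}
    for var, op, args in bindings:
        members.setdefault(partition[var], []).append((var, op, args))

    grouped = {}
    for label, body in members.items():
        defined = {var for var, _, _ in body}
        # Parameters: values used inside the group but defined outside it.
        params = []
        for _, _, args in body:
            for arg in args:
                if arg not in defined and arg not in params:
                    params.append(arg)
        # Outputs: values defined here and needed by another group or returned.
        used_elsewhere = {arg for var, _, args in bindings
                          if partition[var] != label for arg in args}
        outputs = [var for var, _, _ in body
                   if var in used_elsewhere or var in program_outputs]
        grouped[label] = {"params": params, "body": body, "outputs": outputs}
    return grouped


bindings = [
    ("lv0", "conv2d", ["x", "w0"]),
    ("lv1", "relu", ["lv0"]),
    ("lv2", "add", ["lv1", "b"]),
]
# Partition 1: every binding in its own group.
singletons = make_grouped_functions(bindings, {v: v for v, _, _ in bindings}, ("lv2",))
# Partition 2: all three bindings offloaded together.
merged = make_grouped_functions(bindings, dict.fromkeys(["lv0", "lv1", "lv2"], "trt"), ("lv2",))
print(len(singletons), merged["trt"]["params"])  # 3 ['x', 'w0', 'b']
```

Only the `partition` argument differs between the two calls. In the same spirit, `FuseOps`, `FuseOpsByPattern`, and `MergeCompositeFunctions` differ in how they compute the partition, while the grouping step is shared. Note that this toy ignores the order in which grouped functions are emitted.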
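A bug in `OperatorFusor` when a tuple-producing function is involved

This was found while I was working on the complicated example from https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830. Currently, bindings in `DataflowBlockNode` are processed in their original order, but this is incorrect when:

- there is a new grouped function that produces a tuple, and
- an earlier binding depends on variables that are remapped when a call to the tuple-producing function is emitted later.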
See the example below. The group B2 depends on the group A1, which produces a tuple, so the new grouped function for A1 must be emitted before the one for B2. Depending on where the binding for the node in B2 appears in the original order, the grouped function for B2 may be emitted before A1. It then consumes a variable defined in A1 that becomes invalid once that variable is remapped to the result of a `TupleGetItem`.
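The ordering problem can be modeled with a small Python toy. As a simplifying assumption (not taken from `fuse_ops.cc`), each group's call is emitted where its last member binding appears. The sketch compares that order with a topological order of the group dependency graph, computed with the standard-library `graphlib.TopologicalSorter`. All function names here are hypothetical.

```python
from graphlib import TopologicalSorter


def group_deps(bindings, group_of):
    """Group -> set of groups whose results it consumes.

    bindings: list of (var, args) in the original program order.
    group_of: maps every bound var to its group label.
    """
    producer = {var: group_of[var] for var, _ in bindings}
    deps = {group_of[var]: set() for var, _ in bindings}
    for var, args in bindings:
        for arg in args:
            src = producer.get(arg)  # None for function parameters
            if src is not None and src != group_of[var]:
                deps[group_of[var]].add(src)
    return deps


def order_by_last_binding(bindings, group_of):
    """Emit each group where its last binding appears (assumed model of the bug)."""
    last = {group_of[var]: i for i, (var, _) in enumerate(bindings)}
    return [group_of[var] for i, (var, _) in enumerate(bindings)
            if last[group_of[var]] == i]


def order_topologically(bindings, group_of):
    """Emit groups in a topological order of the group dependency graph."""
    return list(TopologicalSorter(group_deps(bindings, group_of)).static_order())


def respects_deps(order, bindings, group_of):
    """True if every group is emitted after all groups it consumes."""
    pos = {g: i for i, g in enumerate(order)}
    deps = group_deps(bindings, group_of)
    return all(pos[src] < pos[g] for g, srcs in deps.items() for src in srcs)


# A1 = {lv0, lv2}; assuming lv2 is returned by the enclosing function, A1 has
# two outputs and therefore returns a tuple. B2 = {lv1} consumes lv0 from A1.
bindings = [("lv0", ["x"]), ("lv1", ["lv0"]), ("lv2", ["x"])]
group_of = {"lv0": "A1", "lv1": "B2", "lv2": "A1"}

naive = order_by_last_binding(bindings, group_of)
fixed = order_topologically(bindings, group_of)
print(naive, respects_deps(naive, bindings, group_of))  # ['B2', 'A1'] False
print(fixed, respects_deps(fixed, bindings, group_of))  # ['A1', 'B2'] True
```

In the naive order, B2 is emitted while `lv0` still refers to the pre-grouping binding, which is exactly the stale-variable situation described above; ordering by group dependencies guarantees A1's tuple result exists before B2 reads from it.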
This is fixed by processing bindings in the order of the topological sort of the group dependency relations.

cc @Hzfengsy
cc @sunggg @psrivas2 @mbaret @gigiblender @mikepapadim
A part of https://github.com/tlc-pack/relax/issues/364