tlc-pack / relax

[BYOC] Add pattern-based partitioning pass #366

Closed · masahi closed this 1 year ago

masahi commented 1 year ago

Part of https://github.com/tlc-pack/relax/issues/364

This adds a new pass, FuseOpsByPattern, which runs pattern matching on each function in the given module and groups matched expressions into new functions. The end result is similar to FuseOps, but fusion is driven entirely by the provided patterns. The implementation also reuses OperatorFusor, which FuseOps uses to create grouped functions from partitioned groups, further illustrating the similarity between the two passes.
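
A rough sketch of how the pass could be driven from Python; the module paths, op names, and pass signature below are assumptions based on this PR's description, not necessarily the final API:

```python
# Hedged sketch: imports, op names, and the pass signature are assumptions.
from tvm import relax
from tvm.relax.dpl import is_op, wildcard

# Describe a conv2d -> relu chain with the dataflow pattern language.
conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
pattern = is_op("relax.nn.relu")(conv)

# Group every match into a new function tagged "dnnl.conv2d_relu"
# (a hypothetical pattern name for a DNNL backend).
mod = relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", pattern)])(mod)
```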

The new pass will serve the same role that the MergeComposite pass plays in Relay BYOC: grouped functions are annotated with the "composite" attribute to denote which operations a given function consists of, and are then offloaded to external backends. But it can also be useful in non-BYOC settings, for example to support advanced fusion that op-kind-based fusion doesn't handle (fused MHA, conv2d / gemm + reduction fusion, etc.).
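
For illustration, a grouped function produced by this partitioning could look roughly like the hypothetical TVMScript below; the shapes and the exact attribute spelling are assumptions:

```python
# Hypothetical TVMScript sketch; shapes and attribute spelling are assumptions.
from tvm.script import relax as R

@R.function
def fused_conv2d_relu(
    x: R.Tensor((1, 64, 56, 56), "float32"),
    w: R.Tensor((64, 64, 3, 3), "float32"),
) -> R.Tensor((1, 64, 54, 54), "float32"):
    # The "composite" attribute records which ops the function consists of,
    # so an external backend can decide whether to offload it.
    R.func_attr({"Composite": "dnnl.conv2d_relu"})
    with R.dataflow():
        lv = R.nn.conv2d(x, w)
        gv = R.nn.relu(lv)
        R.output(gv)
    return gv
```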

@tqchen @Hzfengsy @sunggg @ganler @junrushao

masahi commented 1 year ago

@sunggg fused MHA = fused multi-head attention. It is basically batch_matmul -> softmax -> batch_matmul, and op-kind-based fusion will never fuse such a sequence of complicated ops. We need to fuse and schedule manually if we want to target fused MHA directly.
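
As a sketch, that sequence could be expressed as a pattern along these lines; the op names here are illustrative guesses, not confirmed registry names:

```python
# Hedged sketch of a fused-MHA pattern; op names are illustrative guesses.
from tvm.relax.dpl import is_op, wildcard

scores = is_op("relax.nn.batch_matmul")(wildcard(), wildcard())  # Q x K^T
probs = is_op("relax.nn.softmax")(scores)                        # attention weights
mha = is_op("relax.nn.batch_matmul")(probs, wildcard())          # weights x V
```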

masahi commented 1 year ago

> Thanks @masahi for the contribution!
>
> One high-level question: what is the restriction on the kinds of patterns this pass supports? For example, would it support patterns where multiple internal nodes have external users? Can we partition the graph below using a conv2d -> relu pattern?
>
> ```
> conv1 = conv2d(x, w, ..)
> relu1 = relu(conv1)
> user_conv = gelu(conv1)
> user_relu = gelu(relu1)
> ```
>
> Note that the fused conv2d -> relu function would now have to output a tuple of (conv1, relu1).

@psrivas2 ~I'd say it is a general question for the Relax pattern matcher, rather than something specific to this PR~. The new pass uses the existing pattern matcher as is. That said, I'll experiment with more complicated patterns like your example to see if it breaks.

UPDATE: On second thought, assuming a conv2d -> relu pattern does match your example, it is indeed the responsibility of the partitioner to output a functionally correct partitioned graph.
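
Concretely, since both conv1 and relu1 have users outside the matched group, a correct partition would have to return them as a tuple, along these lines (sketch in the same pseudocode as the example above):

```python
# Sketch: both conv1 and relu1 escape the group, so the fused
# function must return them as a tuple.
def fused_conv2d_relu(x, w):
    conv1 = conv2d(x, w)
    relu1 = relu(conv1)
    return (conv1, relu1)
```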

cc @ganler

ganler commented 1 year ago

> @psrivas2: Would it support patterns where multiple internal nodes have external users? Can we partition the example graph above using a conv2d -> relu pattern?
>
> @masahi: The new pass uses the existing pattern matcher as is. That said, I'll experiment with more complicated patterns like your example to see if it breaks.

Whether conv2d -> relu will be matched depends on the semantics of the pattern expression.

If you just write is_op("relu")(is_op("conv2d")), any graph containing conv2d -> relu will be matched regardless of dependencies, because the pattern's semantics place no constraints on how the intermediate values are used.

However, if you want to match conv2d -> relu where conv2d is used only by relu, you can describe the pattern as is_op("conv2d") >> is_op("relu"), following the previous design.
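
A sketch of the two semantics side by side, assuming the operator overloads from the earlier pattern-language design (where `>>` means only-used-by):

```python
# Sketch; assumes `>>` is the only-used-by operator from the dpl design.
from tvm.relax.dpl import is_op, wildcard

conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())

# Loose: matches any conv2d -> relu pair, even if conv2d has other users.
loose = is_op("relax.nn.relu")(conv)

# Strict: matches only when the conv2d output is used by nothing but the relu.
strict = conv >> is_op("relax.nn.relu")
```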

masahi commented 1 year ago

@psrivas2 Happy to report that the existing code in fuse_ops.cc can already create a fused function that outputs a tuple, and it creates a correct function for your example subgraph. However, it didn't have the correct logic for consuming such tuples at the call site (it tried to emit code like gelu(tuple) for your example).

I added the necessary TupleGetItem nodes and remapping of variables, so the new pass now works on your example. See the commit https://github.com/tlc-pack/relax/pull/366/commits/6a4d93a857b6cb937e6530e243377c587619f15b
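
The rewritten call site then looks roughly like this (sketch; variable names are illustrative):

```python
# Sketch of the call site after the fix: the group call yields a tuple, and
# each outside user reads it through a TupleGetItem instead of receiving the
# tuple itself (the buggy path emitted gelu(tuple) directly).
out = fused_conv2d_relu(x, w)   # returns (conv1, relu1)
user_conv = gelu(out[0])        # TupleGetItem(out, 0)
user_relu = gelu(out[1])        # TupleGetItem(out, 1)
```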

cc @Hzfengsy

psrivas2 commented 1 year ago

> @masahi: I added necessary TupleGetItem and remapping of variables, so the new pass works on your example now.

Awesome, thanks for addressing this!