tlc-pack / relax

[BYOC] Attach matched call nodes as attributes of composite functions #417

Closed masahi closed 1 year ago

masahi commented 1 year ago

Introduce attr::kMatchedCallNodes attribute, a map between an op name and a list of pattern-matched call nodes whose callee is identified by the key op name.

This is useful for finding a call node of interest (usually an anchor op) in a composite function. So far we have been using GetOpInFunction in codegen for that purpose, but if we can simply look up the matched call nodes, that function becomes unnecessary.

If there are multiple call nodes with the same op name in one match (e.g. attention), the codegen is responsible for figuring out which of them is needed in a given situation.
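To make the shape of the proposed attribute concrete, here is a minimal sketch of an op-name-to-call-nodes table. It uses a toy `Call` class rather than the real Relax IR, and `collect_matched_calls` and the op-name strings are hypothetical names chosen for illustration only; this is not the code in this PR.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Call:
    """Toy stand-in for a call node: an op name plus its arguments."""
    op: str
    args: tuple = ()


def collect_matched_calls(matched_exprs):
    """Group matched call nodes by op name (hypothetical helper)."""
    table = defaultdict(list)
    for expr in matched_exprs:
        if isinstance(expr, Call):
            table[expr.op].append(expr)
    return dict(table)


# conv2d -> relu: each op name maps to exactly one call node.
conv = Call("conv2d", ("data", "weight"))
relu = Call("relu", (conv,))
conv_relu_table = collect_matched_calls([conv, relu])

# Attention-like match: two matmul calls end up under the same key,
# so the consumer has to decide which one it needs.
qk = Call("matmul", ("q", "k"))
sm = Call("softmax", (qk,))
out = Call("matmul", (sm, "v"))
attention_table = collect_matched_calls([qk, sm, out])
```

In the conv2d -> relu case a lookup by op name is unambiguous, while the attention-like case returns a two-element list, which is the situation the paragraph above leaves to the codegen.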

cc @vinx13 @yelite

masahi commented 1 year ago

> Isn't codegen interested in every op in the composite function? I have TensorRT-style BYOC in mind. Is this useful for cutlass-style BYOC?

Yes, this is useful for library-based BYOC such as DNNL or cutlass, where one composite function contains multiple ops. For DNNL, if we want to offload a composite function for conv2d -> relu, we need to tell DNNL that (1) this subgraph consists of conv2d + relu and (2) conv2d has such-and-such attributes. (1) is done by passing the name of the pattern, and (2) is done by sending the attributes of the conv2d call node. So we only need to examine the attributes of the "anchor node", which is typically nested in the graph of CallNodes in a composite function.

To find such a nested CallNode, we have previously used GetOpInFunction (GetRootCall in Relay BYOC). But since the same information should be readily available in the matching results, I felt that doing another simple search in the codegen is redundant.
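For reference, the kind of search being discussed can be sketched as a depth-first walk over nested calls. This is only an illustration in the spirit of GetOpInFunction, not its actual implementation; the `Call` class and `find_call_by_op` are hypothetical stand-ins.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Call:
    """Toy stand-in for a call node: an op name plus its arguments."""
    op: str
    args: tuple = ()


def find_call_by_op(expr, op_name):
    """Depth-first search: return the first call whose op is op_name, else None."""
    if not isinstance(expr, Call):
        return None  # leaves (variables, constants) are plain strings here
    if expr.op == op_name:
        return expr
    for arg in expr.args:
        found = find_call_by_op(arg, op_name)
        if found is not None:
            return found
    return None


# Body of a composite function for conv2d -> add(bias) -> relu.
conv = Call("conv2d", ("data", "weight"))
body = Call("relu", (Call("add", (conv, "bias")),))

# The anchor op is nested two levels below the outermost call.
anchor = find_call_by_op(body, "conv2d")
```

Note that a search like this returns only the first hit, so it shares the ambiguity problem when one composite function contains several calls to the same op.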

That said, I do have some concerns about the usefulness of this matching result. For a simple composite function like conv2d -> bias -> relu, the current solution using GetOpInFunction is simple and sufficient. And for a more complicated subgraph like attention, matching results in the form of Map<String, Array<Call>> might not be very useful. The best solution there might be to simply run another pattern match on the codegen side, which lets us extract any CallNode, or its arguments, directly from the map from pattern to matched expr (since the call site has all the individual patterns).

Any thoughts? cc @vinx13 @yelite

vinx13 commented 1 year ago

TVMScript round-trippability is a concern here. Ultimately, running another pattern match during codegen should provide any information we need. So maybe we can use GetOpInFunction along with pattern matching to get the nodes; that can probably avoid some complex options in the attributes.

I also think another round of pattern matching cannot be avoided if we want to do structural queries during codegen: since we lift the subgraphs into a new function, the original pattern match result gets invalidated anyway.

masahi commented 1 year ago

OK, I'm closing this for now. We can revisit it if a good use case arises. The cleanup of the cutlass codegen has been extracted to https://github.com/tlc-pack/relax/pull/426