masahi closed this 1 year ago
Isn't codegen interested in every op in the composite function? I have TensorRT-style BYOC in mind. Is this useful for cutlass-style BYOC?
Yes, this is useful for library-style BYOC like DNNL or cutlass, where one composite function contains multiple ops. For DNNL, if we want to offload a composite function for `conv2d -> relu`, we need to tell DNNL that (1) this subgraph consists of conv2d + relu and (2) the conv2d has such-and-such attributes. (1) is done by passing the name of the pattern, and (2) is done by sending the attributes of the conv2d call node. So we only need to examine the attributes of the "anchor node", which is typically nested inside the graph of `CallNode`s in a composite function.
To find such a nested `CallNode`, we have previously used `GetOpInFunction` (`GetRootCall` in Relay BYOC), but since the same information should be readily available in the matching results, I felt that doing another simple search in the codegen is redundant.
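To make the "search in the codegen" concrete, here is a minimal sketch on a toy IR, assuming nothing about TVM's real classes. `Var`, `Call`, `Function`, and `find_op_in_function` are hypothetical stand-ins: the helper only mirrors the idea behind `GetOpInFunction` (walk the composite body until a call to the anchor op is found), and the pattern name `dnnl.conv2d_relu` is illustrative.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union


@dataclass
class Var:
    """Toy stand-in for a function parameter (not TVM's relax.Var)."""
    name: str


@dataclass
class Call:
    """Toy stand-in for a CallNode: op name, arguments, and op attributes."""
    op: str
    args: List["Expr"]
    attrs: Dict[str, object] = field(default_factory=dict)


Expr = Union[Var, Call]


@dataclass
class Function:
    """Toy composite function: parameters, body, and function-level attributes."""
    params: List[Var]
    body: Expr
    attrs: Dict[str, object] = field(default_factory=dict)


def find_op_in_function(func: Function, op_name: str) -> Optional[Call]:
    """Hypothetical helper: depth-first search for the first call to `op_name`.

    Mirrors the idea of GetOpInFunction; it is not the TVM implementation.
    """
    stack: List[Expr] = [func.body]
    while stack:
        expr = stack.pop()
        if isinstance(expr, Call):
            if expr.op == op_name:
                return expr
            # Push arguments so the leftmost one is visited first.
            stack.extend(reversed(expr.args))
    return None


# Composite function for conv2d -> relu; the conv2d is the nested anchor node.
data, weight = Var("data"), Var("weight")
conv = Call("conv2d", [data, weight], {"strides": (1, 1), "padding": (1, 1)})
composite = Function([data, weight], Call("relu", [conv]),
                     {"Composite": "dnnl.conv2d_relu"})

# (1) the pattern name identifies the fused op, (2) the anchor's attributes configure it.
pattern_name = composite.attrs["Composite"]
anchor = find_op_in_function(composite, "conv2d")
```

The search is trivial for a two-op chain, which is exactly why it feels redundant when the matcher has already seen the anchor call.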
That said, I do have some concerns about the usefulness of this matching result. For a simple composite function like `conv2d -> bias -> relu`, the current solution using `GetOpInFunction` is simple and sufficient. And for a more complicated subgraph like attention, matching results in the form of `Map<String, Array<Call>>` might not be very useful; the best solution might be to simply run another pattern match on the codegen side, which lets us extract any `CallNode` or its arguments directly from the map between patterns and matched expressions (since the call site has all the individual patterns).
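As an illustration of why a pattern-keyed result can disambiguate where an op-name-keyed map cannot, here is a minimal sketch of a toy structural matcher. Everything in it (`Wildcard`, `CallPattern`, `match`, `extract`) is hypothetical and is not TVM's dataflow pattern API; the attention subgraph is simplified to `matmul -> softmax -> matmul` (no transpose or scaling).

```python
from typing import Dict, List, Optional, Union


class Var:
    """Toy function parameter."""
    def __init__(self, name: str):
        self.name = name


class Call:
    """Toy call node: op name, arguments, attributes."""
    def __init__(self, op: str, args: List["Expr"], attrs: Optional[dict] = None):
        self.op, self.args, self.attrs = op, args, attrs or {}


Expr = Union[Var, Call]


class Wildcard:
    """Pattern that matches any expression."""


class CallPattern:
    """Pattern matching a call to `op` whose arguments match `args` pairwise."""
    def __init__(self, op: str, *args):
        self.op, self.args = op, list(args)


Pattern = Union[Wildcard, CallPattern]


def match(pat: Pattern, expr: Expr, bindings: Dict[Pattern, Expr]) -> bool:
    """Structurally match `pat` against `expr`, recording pattern -> expr bindings."""
    if isinstance(pat, Wildcard):
        bindings[pat] = expr
        return True
    if not (isinstance(expr, Call) and expr.op == pat.op
            and len(expr.args) == len(pat.args)):
        return False
    for sub_pat, sub_expr in zip(pat.args, expr.args):
        if not match(sub_pat, sub_expr, bindings):
            return False
    bindings[pat] = expr
    return True


def extract(pat: Pattern, expr: Expr) -> Optional[Dict[Pattern, Expr]]:
    """Return the pattern -> expr map on success, None on failure."""
    bindings: Dict[Pattern, Expr] = {}
    return bindings if match(pat, expr, bindings) else None


# The "call site" keeps every sub-pattern in a variable.
q_pat, k_pat, v_pat = Wildcard(), Wildcard(), Wildcard()
qk_pat = CallPattern("matmul", q_pat, k_pat)
softmax_pat = CallPattern("softmax", qk_pat)
out_pat = CallPattern("matmul", softmax_pat, v_pat)

# Simplified attention body containing two matmuls.
q, k, v = Var("q"), Var("k"), Var("v")
qk = Call("matmul", [q, k])
sm = Call("softmax", [qk], {"axis": -1})
body = Call("matmul", [sm, v])

m = extract(out_pat, body)
first_matmul = m[qk_pat]   # unambiguous, even though both calls are "matmul"
second_matmul = m[out_pat]
```

Keying by pattern object rather than by op name is the design point: both matmuls would share the key `"matmul"` in a name-keyed map, but here each has its own sub-pattern.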
Any thought? cc @vinx13 @yelite
TVMScript round-trippability is a concern here. Ultimately, running another pattern match during codegen should provide any information we need. So maybe we can use `GetOpInFunction` along with pattern matching to get the nodes, which can probably avoid some complex options in the attributes.
I also think another round of pattern matching cannot be avoided if we want to do structural queries during codegen: since we lift the subgraphs into a new function, the original pattern match results get invalidated anyway.
OK, I'm closing this for now; we can revisit it if a good use case arises. The cleanup of the cutlass codegen has been extracted to https://github.com/tlc-pack/relax/pull/426
Introduce the `attr::kMatchedCallNodes` attribute, a map from an op name to the list of pattern-matched call nodes whose callee is that op.

This is useful for finding a call node of interest (usually an anchor op) in a composite function. So far we have been using `GetOpInFunction` in codegen for that purpose, but if we can simply look up the matched call nodes, that function becomes unnecessary.

If there are multiple call nodes with the same op name in one match (e.g. attention), the codegen is responsible for figuring out which of them is needed in a given situation.
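A minimal sketch of the shape of the proposed attribute, assuming a toy `Call` class and a hypothetical `group_matched_calls` helper (neither is TVM code). It groups whatever the matcher bound by op name, the Python analogue of `Map<String, Array<Call>>`, and shows both the single-anchor lookup and the attention-like case where one key holds several calls.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Optional


class Call:
    """Toy stand-in for a CallNode (op name, args, attrs); not TVM's class."""
    def __init__(self, op: str, args: list, attrs: Optional[dict] = None):
        self.op, self.args, self.attrs = op, args, attrs or {}


def group_matched_calls(matched: Iterable[object]) -> Dict[str, List[Call]]:
    """Hypothetical helper: build an op-name -> [matched calls] map.

    `matched` is whatever the pattern matcher bound; non-call bindings
    (e.g. function parameters, here plain strings) are skipped.
    """
    by_op: Dict[str, List[Call]] = defaultdict(list)
    for expr in matched:
        if isinstance(expr, Call):
            by_op[expr.op].append(expr)
    return dict(by_op)


# Single-anchor case: conv2d -> bias add -> relu.
conv = Call("conv2d", ["data", "weight"], {"strides": (2, 2)})
bias = Call("add", [conv, "bias"])
relu = Call("relu", [bias])
matched_call_nodes = group_matched_calls([conv, bias, relu, "data", "weight"])
anchor = matched_call_nodes["conv2d"][0]  # direct lookup, no search of the body

# Attention-like case: two matmuls share one key, so codegen must pick one.
qk = Call("matmul", ["q", "k"])
sm = Call("softmax", [qk])
out = Call("matmul", [sm, "v"])
attn_calls = group_matched_calls([qk, sm, out])
```

The list for `"matmul"` holds both calls in the order they were supplied; nothing in the map itself says which one is the query-key product, which is the ambiguity the description leaves to the codegen.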
cc @vinx13 @yelite