tlc-pack / relax

Apache License 2.0

[BYOC] Fix MergeCompositeFunctions when some callees are not a composite function #385

Closed (masahi closed this 1 year ago)

masahi commented 1 year ago

PR https://github.com/tlc-pack/relax/pull/372 wrongly assumed that every op in a module is a composite function. For example, if the output of FuseOpsByPattern looks like

```python
lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(data, weight)
conv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)  # <- not a composite function
```

MergeCompositeFunctions would try to create a grouped function for relu. That is an error: relu is not a composite function, so its grouped function cannot be assigned the kCodegen attribute. The function is nevertheless created by accident because of the hacky create_single_binding_function_ flag I added in https://github.com/tlc-pack/relax/pull/372. That flag was meant only to ensure that a grouped function is created for a composite function even when it is the only composite function in its group. As a side effect, however, it also caused a grouped function to be created for a single non-composite op.
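To make the failure mode concrete, here is a minimal, simplified Python model of the old grouping logic. It is a sketch, not the pass's actual code: `Binding`, `merge_old`, the composite name `"dnnl.conv2d"`, and the rule that the codegen name is the prefix before the first `.` are all hypothetical assumptions for illustration. With the flag off, single-binding groups are skipped entirely (so even the lone conv2d composite is dropped, which is why the flag was added); with it on, every group gets a function, and the relu group has no codegen to assign.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Binding:
    var: str                  # bound variable, e.g. "lv"
    callee: str               # called function or op
    composite: Optional[str]  # hypothetical "Composite" name; None for a plain op like relu


def merge_old(groups: List[List[Binding]],
              create_single_binding_function: bool = True) -> List[Tuple[str, str]]:
    """Hypothetical model of the pre-fix grouping logic (not the real pass)."""
    created = []
    for group in groups:
        if len(group) == 1 and not create_single_binding_function:
            # Without the flag, single-binding groups are left alone,
            # including a group holding a single composite function.
            continue
        # With the flag set, a grouped function is attempted for every group,
        # even one that holds only a non-composite op.
        head = group[0]
        if head.composite is None:
            raise ValueError(f"cannot assign kCodegen to grouped function for {head.callee}")
        # Assumption: the codegen name is the prefix of the composite name.
        created.append((head.var, head.composite.split(".")[0]))
    return created


# The two bindings from the example above, each in its own group.
groups = [
    [Binding("lv", "fused_relax_nn_conv2d", "dnnl.conv2d")],
    [Binding("conv", "R.nn.relu", None)],
]
```

Calling `merge_old(groups)` raises `ValueError` on the relu group, while `merge_old(groups, create_single_binding_function=False)` returns an empty list because the lone composite group is skipped too.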

This PR removes the create_single_binding_function_ flag and instead adds attribute annotations to Group. Because only a Group that corresponds to a composite function is annotated, we can skip creating a new function whenever a group's annotation is empty. This fixes the bug, and the overall implementation is cleaner.
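The annotation-based approach can be sketched in Python as follows. This is a simplified model under stated assumptions, not the PR's actual implementation: `Group`, `annotate`, `merge_new`, the attribute key `"kCodegen"`, and the prefix-based codegen name are hypothetical stand-ins. The point it shows is that "is this a composite group?" is recorded on the group itself, so a single-binding composite group still produces a function while a plain op such as relu is skipped, with no special-case flag.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class Group:
    bindings: List[str] = field(default_factory=list)
    # Hypothetical stand-in for the attribute annotations added to Group.
    attrs: Dict[str, str] = field(default_factory=dict)


def annotate(group: Group, composite: Optional[str]) -> None:
    # Only a group that corresponds to a composite function is annotated.
    if composite is not None:
        # Assumption: the codegen name is the prefix of the composite name.
        group.attrs["kCodegen"] = composite.split(".")[0]


def merge_new(groups: List[Group]) -> List[Tuple[Tuple[str, ...], str]]:
    """Hypothetical model of the fixed logic: skip groups with empty annotations."""
    created = []
    for group in groups:
        if not group.attrs:
            # Not a composite function: leave the original binding untouched.
            continue
        created.append((tuple(group.bindings), group.attrs["kCodegen"]))
    return created


conv_group = Group(["lv"])
annotate(conv_group, "dnnl.conv2d")  # hypothetical composite name
relu_group = Group(["conv"])
annotate(relu_group, None)           # R.nn.relu is not a composite function
```

Here `merge_new([conv_group, relu_group])` creates a grouped function only for the conv2d group, even though it contains a single binding.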

cc @sunggg