tlc-pack / relax

Apache License 2.0

[BYOC] Fix FuseOpsByPattern and RunCodegen for calling the same extern multiple times #441

Closed. masahi closed this pull request 1 year ago.

masahi commented 1 year ago

The FuseOpsByPattern and RunCodegen passes both have a bug when the same extern function is called multiple times, as in:

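```python
from tvm.script import relax as R

# Excerpt: main() of an IRModule in which fused_relax_nn_conv2d_tensorrt is
# another global function (its definition is omitted here).
@R.function
def main(
    data: R.Tensor((16, 32, 32, 16), dtype="float16"),
    weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
    weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
    with R.dataflow():
        # Both calls target the same fused (extern) function.
        lv = fused_relax_nn_conv2d_tensorrt(data, weight1)
        gv = fused_relax_nn_conv2d_tensorrt(lv, weight2)
        R.output(gv)
    return gv
```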

FuseOpsByPattern: the second call fails at builder_->GetContextIRModule()->Lookup(gvar), because gvar was already removed from the module while visiting the first call, at https://github.com/tlc-pack/relax/blob/02f1ca72d0356fbc52ea11e81b03a0f1e3848a5b/src/relax/transform/fuse_ops.cc#L1008
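The failure mode can be reproduced with a toy model in plain Python. This is only a sketch under stated assumptions, not TVM's C++ pass: the dict-based "module", the function names, and the "_wrapped" suffix are all hypothetical. The buggy version looks up the callee, registers a wrapped replacement, and removes the original global during the visit, so the second lookup fails. The memoized version, one plausible way to fix it (not necessarily the exact change in this PR), remembers the old-to-new mapping so later calls reuse the function created for the first call.

```python
from typing import Dict, List, Tuple

# Hypothetical toy model, not the TVM API: the "module" maps global names to
# function bodies, and the dataflow block is a list of (callee, args) calls.
Module = Dict[str, str]
Call = Tuple[str, List[str]]


def annotate_buggy(mod: Module, calls: List[Call]) -> List[Call]:
    """Wraps each callee in a new function and removes the original global
    while visiting the call, mirroring the failure described above."""
    new_calls = []
    for callee, args in calls:
        body = mod[callee]  # raises KeyError on the second call to `callee`
        new_name = callee + "_wrapped"
        mod[new_name] = f"wrap({body})"
        del mod[callee]  # the original global is gone after the first visit
        new_calls.append((new_name, args))
    return new_calls


def annotate_memoized(mod: Module, calls: List[Call]) -> List[Call]:
    """Same rewrite, but remembers old name -> new name so every later call
    to the same callee reuses the function created for the first call."""
    memo: Dict[str, str] = {}
    new_calls = []
    for callee, args in calls:
        if callee not in memo:
            body = mod[callee]
            new_name = callee + "_wrapped"
            mod[new_name] = f"wrap({body})"
            del mod[callee]
            memo[callee] = new_name
        new_calls.append((memo[callee], args))
    return new_calls
```

With calls = [("fused_conv2d", ["data", "weight1"]), ("fused_conv2d", ["lv", "weight2"])], annotate_buggy raises KeyError on the second call, while annotate_memoized rewrites both calls to target the single "fused_conv2d_wrapped" function.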

RunCodegen: the second call to fused_relax_nn_conv2d_tensorrt is not replaced with the extern function created during the first call. While visiting the first call, the kCodegen attribute is removed at https://github.com/tlc-pack/relax/blob/02f1ca72d0356fbc52ea11e81b03a0f1e3848a5b/src/relax/transform/run_codegen.cc#L91, so the condition at https://github.com/tlc-pack/relax/blob/02f1ca72d0356fbc52ea11e81b03a0f1e3848a5b/src/relax/transform/run_codegen.cc#L107-L108 evaluates to false during the second visit.
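The RunCodegen problem has the same shape, sketched here as another hypothetical plain-Python toy model (not the TVM API). Each function carries an attribute dict, and a "Codegen" key stands in for the kCodegen attribute. Stripping the attribute on the first visit makes the check fail on the second visit, so that call is silently left unreplaced. Caching the extern name per callee, which is an assumed fix for illustration, keeps both calls pointing at the same generated function.

```python
from typing import Dict, List, Tuple

# Hypothetical toy model, not the TVM API: each function has an attribute
# dict; a "Codegen" entry (standing in for kCodegen) marks it for offloading.
Functions = Dict[str, Dict[str, str]]
Call = Tuple[str, List[str]]


def run_codegen_buggy(funcs: Functions, calls: List[Call]) -> List[Call]:
    out = []
    for callee, args in calls:
        attrs = funcs[callee]
        if "Codegen" in attrs:  # false on the second visit to the same callee
            extern = f"{attrs['Codegen']}_{callee}"
            del attrs["Codegen"]  # attribute stripped during the first visit
            out.append((extern, args))
        else:
            out.append((callee, args))  # second call is left unreplaced
    return out


def run_codegen_memoized(funcs: Functions, calls: List[Call]) -> List[Call]:
    # Maps each offloaded callee to the extern function generated for it.
    extern_of: Dict[str, str] = {}
    out = []
    for callee, args in calls:
        if callee not in extern_of and "Codegen" in funcs[callee]:
            extern_of[callee] = f"{funcs[callee]['Codegen']}_{callee}"
            del funcs[callee]["Codegen"]
        # Reuse the cached extern name; non-offloaded callees stay unchanged.
        out.append((extern_of.get(callee, callee), args))
    return out
```

Checking the cache before the attribute is what lets the second call find the extern function even though the attribute is already gone.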

Both issues are fixed, and a new CUTLASS test case demonstrates kernel sharing between multiple call_tir calls that invoke the same callee kernel.

@vinx13 @yelite