tlc-pack / relax


Allow keeping bound constants in OperatorFusor (rather than lifting to parameters) #397

Closed · masahi closed this 1 year ago

masahi commented 1 year ago

While working on TensorRT BYOC, I found that the TensorRT runtime requires weight tensors to be passed as constants at compile time, so we need to run BindParams before partitioning.
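
For context, the binding step looks something like this (a minimal sketch; the `mod` and `params` names here are illustrative, not part of this PR):

from tvm import relax

# Bind the weight values into "main" as relax constants before partitioning.
# `params` maps parameter names to their concrete tensor values.
mod = relax.transform.BindParams("main", params)(mod)

With the parameters bound, fusion currently produces the following IR: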

@tvm.script.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        # block 0
        with R.dataflow():
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu(data, metadata["relax.expr.Constant"][0])
            R.output(gv)
        return gv

    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu(data1: R.Tensor((1, 64, 56, 56), dtype="float32"), param_0: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
      ...

Instead, what we need for BYOC is:

@tvm.script.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        # block 0
        with R.dataflow():
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu(data)
            R.output(gv)
        return gv

    @R.function
    def fused_relax_nn_conv2d_relax_nn_relu(data1: R.Tensor((1, 64, 56, 56), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        # function attr dict
        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
        # block 0
        with R.dataflow():
            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(data1, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="")
            gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
            R.output(gv1)
        return gv1

This PR adds an option to OperatorFusor that keeps bound constants in their original positions instead of lifting them to parameters of the fused function. The option is enabled when OperatorFusor is invoked by the FuseOpsByPattern pass.
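
As a rough sketch of the intended BYOC flow (the pattern construction and the exact FuseOpsByPattern signature below are assumptions for illustration, not part of this PR's diff):

from tvm import relax
from tvm.relax.dpl import is_op, wildcard

# Illustrative conv2d + relu pattern matching the fused function above.
data = wildcard()
weight = wildcard()
conv = is_op("relax.nn.conv2d")(data, weight)
pattern = is_op("relax.nn.relu")(conv)

# After BindParams, the weight is already a constant; with this change,
# FuseOpsByPattern keeps it inside the fused function rather than lifting
# it to a parameter of the fused function.
mod = relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", pattern)])(mod)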

@Hzfengsy @tqchen