apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Relax] Implement relax.transform.RemoveSymbolicExpressionsInSubroutine #17080

Open Lunderberg opened 1 month ago

Lunderberg commented 1 month ago

This is a follow-up commit to https://github.com/apache/tvm/pull/16637, which updated relax.transform.FuseOps to provide additional parameters defining symbolic variables required by the fused functions. While this ensures that relax.transform.FuseOps produces well-formed Relax functions, these additional arguments can break some kernel implementations.

This commit implements a new transform, RemoveSymbolicExpressionsInSubroutine, to resolve this issue. The transform identifies function arguments whose sole purpose is to define symbolic variables used only within a symbolic expression, when that expression could instead be inferred from tensor shapes.

For example, consider the following Relax function:

@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):

    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output

The data tensor may be used to infer hidden_size, but cannot be used to infer batch_size or seq_len. The R.Shape parameter exists solely to define batch_size and seq_len, since every symbolic variable must be defined by some parameter. However, neither batch_size nor seq_len is ever used outside of the expression batch_size * seq_len, and the value of batch_size * seq_len could be inferred from the shape of the data tensor.
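To see why the fused product is recoverable while its individual factors are not, consider concrete shapes. The sketch below is a plain-Python illustration with made-up values, not TVM code:

```python
# Illustration (not TVM code): the fused dimension batch_size * seq_len
# is readable directly from the data tensor's shape, but recovering
# batch_size and seq_len individually is ambiguous.

def infer_from_shape(data_shape):
    # data has shape (batch_size * seq_len, hidden_size)
    fused_dim, hidden_size = data_shape
    # Every factorization of fused_dim is a candidate
    # (batch_size, seq_len) pair, so the individual values
    # cannot be recovered from the shape alone.
    candidates = [(b, fused_dim // b)
                  for b in range(1, fused_dim + 1)
                  if fused_dim % b == 0]
    return fused_dim, hidden_size, candidates

fused, hidden, candidates = infer_from_shape((12, 256))
# fused == 12 is known exactly; candidates includes (2, 6), (3, 4),
# (4, 3), ..., so batch_size and seq_len remain ambiguous.
```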

This new transform identifies cases where an argument is otherwise unnecessary, and replaces the symbolic expression with a fresh symbolic variable (data_dim0 below) whose value can be inferred from a tensor shape. This leaves the dummy_arg: R.Shape parameter entirely unused, so a later application of relax.transform.RemoveUnusedParameters() can remove it altogether.

@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
  ):

    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
Lunderberg commented 1 month ago

This transform is intended to be used in the implementation of https://github.com/apache/tvm/pull/16450, as recommended here.