This is a follow-up commit to https://github.com/apache/tvm/pull/16637, which updated `relax.transform.FuseOps` to provide additional parameters defining symbolic variables required by the fused functions. While this ensures that `relax.transform.FuseOps` produces well-formed Relax functions, these additional arguments can break some kernel implementations.
This commit implements a new transform, `RemoveSymbolicExpressionsInSubroutine`, to resolve this issue. This transform identifies function arguments whose sole purpose is to compute a symbolic expression, when that symbolic expression could instead be inferred from tensor shapes.
For example, consider the following Relax function:
The `data` tensor may be used to infer `hidden_size`, but cannot be used to infer `batch_size` or `seq_len`. The `R.Shape` parameter exists solely to define `batch_size` and `seq_len`, since all symbolic variables must be defined. However, neither `batch_size` nor `seq_len` is ever used outside of the expression `batch_size * seq_len`, and the value of `batch_size * seq_len` could be inferred from the shape of the `data` tensor.
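The condition described above can be illustrated with a toy model. This is a minimal sketch, not TVM's implementation or API: it assumes each tensor dimension is a product of symbolic variables, stored as a sorted tuple of names, and the helpers `norm`, `inferable_vars`, and `replaceable_exprs` are hypothetical. A composite dimension qualifies when every variable in it that cannot be inferred on its own appears nowhere except inside that exact expression.

```python
# Hypothetical toy model (not TVM's API): a dimension is a product of symbolic
# variables, stored as a sorted tuple, e.g. ("batch_size", "seq_len").
Dim = tuple[str, ...]


def norm(*names: str) -> Dim:
    """Normalize a product of variables so that a*b and b*a compare equal."""
    return tuple(sorted(names))


def inferable_vars(tensor_shapes: list[list[Dim]]) -> set[str]:
    """Variables that appear on their own as a dimension of some tensor argument."""
    return {dim[0] for shape in tensor_shapes for dim in shape if len(dim) == 1}


def replaceable_exprs(tensor_shapes: list[list[Dim]], other_uses: list[Dim]) -> set[Dim]:
    """Composite tensor dimensions whose non-inferable variables occur nowhere else.

    `other_uses` lists every symbolic expression used elsewhere in the function
    (e.g. in the body or the return type).
    """
    known = inferable_vars(tensor_shapes)
    all_exprs = [dim for shape in tensor_shapes for dim in shape] + list(other_uses)
    result = set()
    for shape in tensor_shapes:
        for dim in shape:
            if len(dim) < 2:
                continue  # a lone variable is already inferable from this tensor
            unknown = set(dim) - known
            # Every expression mentioning an unknown variable must be exactly `dim`.
            if unknown and all(e == dim for e in all_exprs if unknown & set(e)):
                result.add(dim)
    return result
```

For the example above, with `data` modeled as `[norm("batch_size", "seq_len"), norm("hidden_size")]`, the expression `batch_size * seq_len` is reported as replaceable. If the body also used `seq_len` on its own, nothing would be reported, since `seq_len` would then still need its own definition.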
This new transform identifies cases where an argument is otherwise unnecessary, and replaces the symbolic expression with a new argument. This leaves `dummy_arg: R.Shape` entirely unused, so a later call to `relax.transform.RemoveUnusedParameters()` can remove the parameter altogether.
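The two-stage effect (substitute the expression, then drop the now-unused shape parameter) can be sketched with the same kind of toy model. This is an illustration under stated assumptions, not the transform itself: parameters are a hypothetical `dict` mapping names to `("tensor", dims)` or `("shape", var_names)`, the fresh variable name `num_tokens` is made up for the example, and the caller is assumed to have already verified that the variables in `expr` appear nowhere else. Unlike TVM, where `RemoveUnusedParameters` is a separate pass, this sketch performs the cleanup inline.

```python
# Hypothetical toy model (not TVM's API): each dimension is a sorted tuple of
# variable names whose product is that dimension.
Dim = tuple[str, ...]


def replace_expr(dims: list[Dim], expr: Dim, fresh: str) -> list[Dim]:
    """Replace each dimension equal to `expr` with the lone variable `fresh`."""
    return [(fresh,) if dim == expr else dim for dim in dims]


def remove_symbolic_expr(params: dict, body_dims: list[Dim], expr: Dim, fresh: str):
    """Substitute `expr` with `fresh`, then drop shape params that define nothing used.

    `params` maps a name to ("tensor", dims) or ("shape", var_names).
    """
    # Step 1: substitute `expr` wherever it appears as a dimension.
    new_params = {
        name: (kind, replace_expr(vals, expr, fresh) if kind == "tensor" else vals)
        for name, (kind, vals) in params.items()
    }
    new_body = replace_expr(body_dims, expr, fresh)

    # Step 2: collect variables still referenced by tensor shapes or the body.
    used = {v for dim in new_body for v in dim}
    for kind, vals in new_params.values():
        if kind == "tensor":
            used.update(v for dim in vals for v in dim)

    # Step 3: a shape parameter whose variables are all unused can be removed.
    kept = {
        name: (kind, vals)
        for name, (kind, vals) in new_params.items()
        if kind != "shape" or any(v in used for v in vals)
    }
    return kept, new_body
```

With `data` shaped `[("batch_size", "seq_len"), ("hidden_size",)]` and `dummy_arg` defining `["batch_size", "seq_len"]`, replacing `("batch_size", "seq_len")` by `num_tokens` leaves only `data`, now shaped `[("num_tokens",), ("hidden_size",)]`. If the body still referenced `seq_len` directly, `dummy_arg` would be kept.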