[Pallas] Add a cost estimator for Pallas/JAX functions.
This helps resolve https://github.com/jax-ml/jax/issues/24539, where invoking HLO's cost analysis triggers compilation, which can run out of memory (OOM) for larger inputs. This cost estimator uses only abstract evaluation, so it should work for inputs of any size.
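As a rough illustration of the abstract-evaluation idea (not the Pallas implementation), the sketch below traces a function on `jax.ShapeDtypeStruct` inputs with `jax.make_jaxpr` and counts FLOPs by walking the resulting jaxpr. No buffers are allocated and nothing is compiled. The helper `estimate_flops` is hypothetical. It only inspects top-level equations, so it is limited to functions written directly with `lax` primitives, and it counts only `dot_general` and a few element-wise ops.

```python
import jax
import numpy as np


def estimate_flops(fun, *args):
    """Hypothetical helper: estimate FLOPs of `fun` via abstract tracing only.

    `args` may be jax.ShapeDtypeStruct values, so tracing never touches
    device memory and no HLO is compiled.
    """
    closed_jaxpr = jax.make_jaxpr(fun)(*args)
    total = 0
    # Only top-level equations are inspected; nested jaxprs (e.g. from
    # jit-wrapped helpers) are not traversed in this sketch.
    for eqn in closed_jaxpr.jaxpr.eqns:
        name = eqn.primitive.name
        out_size = int(np.prod(eqn.outvars[0].aval.shape))
        if name == "dot_general":
            lhs = eqn.invars[0].aval
            (lhs_contract, _), _ = eqn.params["dimension_numbers"]
            k = int(np.prod([lhs.shape[d] for d in lhs_contract]))
            # One multiply and one add per contracted element per output.
            total += 2 * out_size * k
        elif name in ("add", "sub", "mul", "div", "max", "min"):
            total += out_size
    return total


def layer(x, w, b):
    # Written with lax primitives so each op is a top-level jaxpr equation.
    return jax.lax.add(jax.lax.dot(x, w), b)
```

For example, `estimate_flops(layer, jax.ShapeDtypeStruct((1024, 512), jnp.float32), jax.ShapeDtypeStruct((512, 256), jnp.float32), jax.ShapeDtypeStruct((1024, 256), jnp.float32))` counts the matmul and the bias add. Because only shapes and dtypes are used, the same call on 65536x65536 operands returns immediately instead of allocating tens of gigabytes.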