jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas] Add a cost estimator for Pallas/JAX functions. #24809

Closed by copybara-service[bot] 1 day ago

copybara-service[bot] commented 4 days ago


Helps resolve https://github.com/jax-ml/jax/issues/24539, in which invoking HLO's cost analysis triggers compilation, and that compilation can run out of memory (OOM) for large inputs. This cost estimator relies only on abstract evaluation (shapes and dtypes, no device buffers), so it should work for inputs of any size.
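To illustrate why abstract evaluation sidesteps the OOM, here is a minimal sketch that is not the estimator added in this PR. `estimate_flops` is a hypothetical helper that traces a function to a jaxpr using only `jax.ShapeDtypeStruct` inputs, so nothing is compiled or allocated. It then sums a rough cost per equation. The counting rule is an assumption chosen for illustration: 2·(output elements)·(contracted size) for `dot_general`, one op per output element for everything else. It also does not recurse into nested jaxprs (e.g. `pjit`, `scan`, `pallas_call`).

```python
import math

import jax
import jax.numpy as jnp


def estimate_flops(fun, *args):
    """Hypothetical FLOP estimate from a jaxpr traced on abstract inputs.

    `args` may be jax.ShapeDtypeStruct values, so no device memory is
    allocated and no XLA compilation happens.
    """
    closed = jax.make_jaxpr(fun)(*args)
    flops = 0
    for eqn in closed.jaxpr.eqns:
        if eqn.primitive.name == "dot_general":
            lhs = eqn.invars[0].aval
            # dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
            (lhs_contract, _), _ = eqn.params["dimension_numbers"]
            out = eqn.outvars[0].aval
            k = math.prod(lhs.shape[d] for d in lhs_contract)
            # One multiply and one add per (output element, contracted index).
            flops += 2 * math.prod(out.shape) * k
        else:
            # Assumed cost: one op per output element; nested jaxprs are ignored.
            flops += sum(math.prod(v.aval.shape) for v in eqn.outvars)
    return flops


# Shapes this large would need ~40 GB each if materialized; here they are
# only shape/dtype descriptions, so the estimate runs instantly.
big = jax.ShapeDtypeStruct((100_000, 100_000), jnp.float32)
print(estimate_flops(jax.lax.dot, big, big))
```

The design point is that `jax.make_jaxpr` only needs shapes and dtypes to trace, whereas a compile-based cost analysis has to build and optimize an XLA executable for the concrete sizes, which is the step that issue #24539 reports running out of memory.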