jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Remove _pjit_lower_cached cache. We can simplify the caching of jit as we have downstream caches and a cpp cache too. #25094

Closed copybara-service[bot] closed 6 hours ago

copybara-service[bot] commented 7 hours ago

Remove _pjit_lower_cached cache. We can simplify the caching of jit as we have downstream caches and a cpp cache too.

If a call drops out of the C++ cache, it is going to be slow anyway, so an extra Python-level cache around lowering adds little.
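
To make the layering argument concrete, here is a minimal pure-Python sketch of a fast-path cache in front of an uncached lowering step and a cached downstream step. This is not JAX's implementation: `fast_path`, `lower`, `compile_cached`, and `call` are hypothetical names chosen for illustration, and the string "signatures" stand in for whatever keys the real caches use.

```python
from functools import lru_cache

# Counters so we can see which layers actually did work.
calls = {"lower": 0, "compile": 0}

@lru_cache(maxsize=None)
def compile_cached(signature):
    # Hypothetical stand-in for a downstream cache (e.g. compilation).
    calls["compile"] += 1
    return f"executable<{signature}>"

def lower(signature):
    # Hypothetical lowering step with no cache of its own,
    # mirroring the idea of removing the intermediate cache.
    calls["lower"] += 1
    return compile_cached(signature)

# Hypothetical stand-in for the fast-path (C++) cache.
fast_path = {}

def call(signature):
    exe = fast_path.get(signature)
    if exe is None:
        # Fast-path miss: take the slow route through lowering.
        exe = lower(signature)
        fast_path[signature] = exe
    return exe

call("f32[3]")      # miss everywhere: lowers and compiles
call("f32[3]")      # fast-path hit: no lowering, no compile
call("f32[4]")      # new signature: lowers and compiles
fast_path.clear()   # simulate dropping out of the fast path
call("f32[3]")      # lowers again, but the downstream cache still hits
print(calls)        # {'lower': 3, 'compile': 2}
```

In this sketch, losing the fast path costs a repeated lowering, while the downstream cache still avoids the repeated compile; the intermediate cache would only help on a path that is already slow.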