google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

[JAX] Replace uses of jax.xla_computation() with jax.jit().lower(). #603

Closed: copybara-service[bot] closed this 1 year ago.

copybara-service[bot] commented 1 year ago


jax.xla_computation() is deprecated in favor of jax.jit(...).lower(...).
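A minimal before/after sketch, assuming a recent JAX release. The function f and the input shape are illustrative and do not come from this PR. jax.jit(f).lower(x) traces and lowers f for the given argument shapes without compiling or running it, which covers the role jax.xla_computation() used to play.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative function; any jittable function can be lowered the same way.
    return jnp.sin(x) * 2.0


x = jnp.ones((4,), dtype=jnp.float32)

# Deprecated pattern, shown for comparison only (not executed here):
#   comp = jax.xla_computation(f)(x)

# Replacement: trace and lower through jax.jit. Nothing is compiled or executed.
lowered = jax.jit(f).lower(x)

# Lowered.as_text() returns the lowered module as text (StableHLO by default).
text = lowered.as_text()
print(text)
```

The Lowered object is the common starting point for both replacements mentioned in this PR.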

The most common replacements are either jax.jit(...).lower(...).compiler_ir(dialect='hlo'), which yields the HLO computation, or jax.jit(...).lower(...).cost_analysis(), which yields XLA's cost estimates.
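A sketch of both replacements, assuming a recent JAX release. The function f and the shapes of x and w are hypothetical examples. The return type of cost_analysis() has varied across JAX versions (a dict, or a list with one dict per computation), and it may be None on backends that do not support it, so the code normalizes it before use.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    # Hypothetical example: a matmul followed by tanh.
    return jnp.tanh(x @ w)


x = jnp.ones((8, 16), jnp.float32)
w = jnp.ones((16, 4), jnp.float32)

lowered = jax.jit(f).lower(x, w)

# Replacement 1: request the HLO dialect (the default is StableHLO). This is the
# closest match for callers that consumed the XlaComputation from xla_computation().
hlo = lowered.compiler_ir(dialect='hlo')
hlo_text = hlo.as_hlo_text()

# Replacement 2: XLA's cost estimates (e.g. "flops"). Normalize a list result
# to its first entry; leave None as-is.
cost = lowered.cost_analysis()
if isinstance(cost, (list, tuple)):
    cost = cost[0] if cost else None

print(hlo_text.splitlines()[0])
if cost is not None:
    print("flops:", cost.get("flops"))
```

Which replacement applies depends on what the old call site did with the XlaComputation: inspect or serialize HLO (compiler_ir), or read cost estimates (cost_analysis).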