google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

[JAX] Replace uses of jax.xla_computation() with jax.jit().lower(). #1772

Closed: copybara-service[bot] closed this pull request 1 year ago.

copybara-service[bot] commented 1 year ago


jax.xla_computation() is deprecated in favor of jax.jit(...).lower(...).

The most common replacements are jax.jit(...).lower(...).compiler_ir(dialect='hlo'), which returns the HLO computation, and jax.jit(...).lower(...).cost_analysis(), which returns the compiler's cost estimates.
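A minimal sketch of both replacements follows. It assumes a recent JAX release where `jax.stages.Lowered` exposes `compiler_ir(dialect='hlo')` and `cost_analysis()`. The function `f` and the `(4,)` float32 shapes are illustrative placeholders and are not taken from the Trax code. `cost_analysis()` can return `None` on backends that do not provide estimates, and some older JAX versions return a list of dicts, so the sketch normalizes the result defensively.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Illustrative function only; not from the Trax codebase.
    return jnp.sin(x) * y + 1.0


# Abstract argument specs: lower() needs only shapes and dtypes, not real data.
spec = jax.ShapeDtypeStruct((4,), jnp.float32)

# Old (deprecated): comp = jax.xla_computation(f)(x, y)
lowered = jax.jit(f).lower(spec, spec)

# Replacement 1: the HLO computation for the lowered function.
hlo = lowered.compiler_ir(dialect='hlo')
hlo_text = hlo.as_hlo_text()
print(hlo_text.splitlines()[0])

# Replacement 2: backend cost estimates (None if the backend has none).
costs = lowered.cost_analysis()
if isinstance(costs, list):
    # Some older JAX versions return a list with one dict per computation.
    costs = costs[0] if costs else None
if costs is not None:
    print("flops:", costs.get("flops"))
```

Lowering with `jax.ShapeDtypeStruct` avoids allocating input arrays just to inspect the computation; passing concrete arrays of the same shape and dtype to `lower()` also works.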