iree-org / iree-jax

Apache License 2.0

[JAX] Replace uses of jax.xla_computation() with jax.jit().lower(). #55

Closed: jpienaar closed this 1 year ago.

jpienaar commented 1 year ago

jax.xla_computation() is deprecated in favor of jax.jit(...).lower(...).
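A minimal sketch of the migration, assuming a recent JAX release. The function `f` and its inputs are hypothetical stand-ins for real model code. The deprecated call appears only as a comment because it no longer exists in newer JAX versions. The replacement traces through `jax.jit`, lowers with abstract `jax.ShapeDtypeStruct` inputs, and can optionally compile the resulting `Lowered` object.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical toy function standing in for a model's forward pass.
    return jnp.tanh(x @ y) + 1.0


x = jnp.ones((2, 3), dtype=jnp.float32)
y = jnp.ones((3, 4), dtype=jnp.float32)

# Deprecated pattern (shown for comparison only, not executed):
#   comp = jax.xla_computation(f)(x, y)
#   hlo_text = comp.as_hlo_text()

# Replacement: trace and lower through jit. Abstract shapes/dtypes are
# enough for lowering, so no concrete data is required at this step.
spec_x = jax.ShapeDtypeStruct(x.shape, x.dtype)
spec_y = jax.ShapeDtypeStruct(y.shape, y.dtype)
lowered = jax.jit(f).lower(spec_x, spec_y)

# Textual MLIR module for the lowered computation
# (StableHLO in recent JAX releases).
ir_text = lowered.as_text()

# The Lowered object can also be compiled and run on concrete inputs.
compiled = lowered.compile()
result = compiled(x, y)
```

Unlike `jax.xla_computation()`, which returned an XLA `XlaComputation`, `lower()` yields an MLIR module. That form is the more natural input for an MLIR-based compiler such as IREE.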