rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License
37 stars, 1 fork

JAX v0.4.35 - xla_client.register_custom_call_target has been deprecated #17

Closed by adam-hartshorne 1 week ago

adam-hartshorne commented 1 week ago

FYI, the semi-public API jax.lib.xla_client.register_custom_call_target has been deprecated in the latest JAX release (v0.4.35), and it looks like torch2jax uses it.

https://github.com/jax-ml/jax/releases/tag/jax-v0.4.35

rdyro commented 1 week ago

Thanks for the update! This is probably the last wake-up call; I'll migrate to the new FFI interface ASAP.

Thankfully, register_custom_call_target still works in tests, so there are no breaking changes yet.
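
For anyone who needs to support several JAX versions during a migration like this, here is a minimal sketch of probing for a registration entry point at import time. It is not taken from torch2jax. The helper name find_registration_api and the candidate list are hypothetical, and the module paths jax.ffi and jax.extend.ffi are assumptions about where JAX exposes register_ffi_target in different releases, so check them against the version you have installed.

```python
import importlib
import warnings

# Candidate entry points, newest first. These module paths and names are
# assumptions about JAX's layout across versions, not taken from torch2jax.
CANDIDATES = [
    ("jax.ffi", "register_ffi_target"),
    ("jax.extend.ffi", "register_ffi_target"),
    ("jax.lib.xla_client", "register_custom_call_target"),
]


def find_registration_api(candidates=CANDIDATES):
    """Return the first (module_path, attr_name) pair that resolves, else None."""
    for module_path, attr_name in candidates:
        try:
            # Deprecated attributes may emit warnings on access; silence them
            # here because we only want to know whether the name exists.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                module = importlib.import_module(module_path)
                if hasattr(module, attr_name):
                    return (module_path, attr_name)
        except Exception:
            # Module missing or removed in this version: try the next one.
            continue
    return None


if __name__ == "__main__":
    # Prints the selected pair, or None if JAX is not installed.
    print(find_registration_api())
```

This only selects which entry point is available. The legacy custom-call registration and the FFI registration may expect different arguments and handler types, so the call site still needs a separate code path for each.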