google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

DeprecationWarning in `transform_with_state` because of `jax.xla` vs `jax.interpreters.xla` #634

Status: Open. joeryjoery opened this issue 1 year ago.

joeryjoery commented 1 year ago

Hi, when calling `hk.transform_with_state`, Haiku internally accesses `jax.xla`, which is deprecated in favor of `jax.interpreters.xla`. The access happens in the check that the function `f` passed to Haiku has not already been JAX-transformed.

The offending function is `check_not_jax_transformed` (at least, that's as far as I've looked into the code; there may be other references).
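For illustration, a check of this kind can avoid `jax.xla` entirely by comparing against the type of a freshly jitted function. This is a minimal sketch, not Haiku's actual implementation: `looks_jit_transformed` is a hypothetical name, and the approach assumes a JAX version in which `jax.jit` returns a dedicated wrapper type (as current releases do).

```python
import jax
import jax.numpy as jnp

# Hypothetical helper, not part of Haiku: derive the wrapper type from
# jax.jit itself instead of reaching into the deprecated jax.xla module.
_JIT_WRAPPER_TYPE = type(jax.jit(lambda x: x))


def looks_jit_transformed(f) -> bool:
  """Returns True if `f` appears to be the output of `jax.jit`."""
  return isinstance(f, _JIT_WRAPPER_TYPE)


def forward(x):
  return jnp.sin(x)


print(looks_jit_transformed(forward))           # False
print(looks_jit_transformed(jax.jit(forward)))  # True
```

Other transforms (e.g. `jax.pmap`) could presumably be covered the same way by adding their wrapper types to the `isinstance` check, though that is untested here.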

Could this be updated? It would help silence the warnings in my test output :).
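Until the check is updated, the warning can be suppressed locally in a test. The sketch below uses only the standard `warnings` module; `noisy_transform` is a hypothetical stand-in for the Haiku call that emits the `DeprecationWarning`.

```python
import warnings


def noisy_transform():
  # Hypothetical stand-in for hk.transform_with_state(...), which
  # emits a DeprecationWarning when it touches jax.xla.
  warnings.warn("jax.xla is deprecated", DeprecationWarning, stacklevel=2)
  return "transformed"


def call_quietly():
  # Ignore DeprecationWarning only inside this block; the previous
  # warning filters are restored when the block exits.
  with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    return noisy_transform()
```

With pytest, decorating a test with `@pytest.mark.filterwarnings("ignore::DeprecationWarning")` has a similar effect, scoped to that test.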