Hi, when calling `hk.transform_with_state`, Haiku internally accesses `jax.xla`, which is deprecated in favor of `jax.interpreters.xla`. The access happens in the check that the function `f` passed to the Haiku transform has not already been JAX-transformed.
I.e., the culprit is `check_not_jax_transformed` (at least, that's as far as I've looked into the code; there may be more references).
Could this be updated? That would help silence my test output :).
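Until the reference is updated, one possible stopgap for noisy test output is to suppress `DeprecationWarning` locally around the offending call. Below is a minimal sketch using only the standard `warnings` module. `call_without_deprecation_warnings` and `legacy_lookup` are hypothetical names: `legacy_lookup` merely stands in for whichever Haiku call emits the `jax.xla` warning, and its message text is made up. This only hides the warning and does not fix the underlying access.

```python
import warnings
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def call_without_deprecation_warnings(fn: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """Call fn with DeprecationWarning ignored; the previous filters are restored on exit."""
    with warnings.catch_warnings():
        # Only affects warnings raised while fn runs inside this block.
        warnings.simplefilter("ignore", DeprecationWarning)
        return fn(*args, **kwargs)


def legacy_lookup() -> str:
    """Hypothetical stand-in for library code that touches a deprecated attribute."""
    warnings.warn("jax.xla is deprecated", DeprecationWarning, stacklevel=2)
    return "ok"
```

Because the filter is scoped to a single call, other deprecations elsewhere in the test suite still surface. For a suite-wide approach, pytest's `filterwarnings` marker or ini option can achieve the same effect.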