google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Change to jax.interpreters.xla for JAX==0.4.14 #693

Closed. kmheckel closed this issue 1 year ago.

kmheckel commented 1 year ago

Hi,

Just opening this issue to raise awareness of changes that will be needed for JAX 0.4.14. I would be happy to try to make some of the updates, but this would be the first time I've ever tried submitting a pull request. Thanks!

In JAX 0.4.14, the following APIs have been removed after a previous deprecation:

jax.ad: use jax.interpreters.ad.

jax.curry: use curry = lambda f: partial(partial, f) (with partial from functools).

jax.partial_eval: use jax.interpreters.partial_eval.

jax.pxla: use jax.interpreters.pxla.

jax.xla: use jax.interpreters.xla.

jax.ShapedArray: use jax.core.ShapedArray.

jax.interpreters.pxla.device_put: use jax.device_put().

jax.interpreters.pxla.make_sharded_device_array: use jax.make_array_from_single_device_arrays().

jax.interpreters.pxla.ShardedDeviceArray: use jax.Array.

jax.numpy.DeviceArray: use jax.Array.
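Most of these are plain import or name changes. The jax.curry entry is the only one that asks you to define something yourself, so here is a minimal sketch of how the suggested one-liner behaves. The function scale_and_shift is hypothetical and exists only to illustrate the call pattern.

```python
from functools import partial

# Replacement for the removed jax.curry, as suggested in the JAX 0.4.14 notes.
# curry(f)(*args) returns partial(f, *args).
curry = lambda f: partial(partial, f)


def scale_and_shift(scale, shift, x):
    # Hypothetical example function, used only to illustrate currying.
    return scale * x + shift


# Fix the first two positional arguments, leaving a function of x alone.
affine = curry(scale_and_shift)(2.0, 1.0)
print(affine(3.0))  # 7.0
```

Note that this only binds leading positional arguments in a single call; it is not one-argument-at-a-time currying in the Haskell sense, which appears to match how the old jax.curry was used.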

tomhennigan commented 1 year ago

Hi @kmheckel, thank you for letting us know. I believe we have already made all of these required changes at HEAD.

In general, since Haiku is developed inside Google's monorepo (and mirrored live to GitHub), these sorts of renames are typically handled automatically for us.

Are you having issues with Haiku + JAX 0.4.14? Can you paste the error message if so?

kmheckel commented 1 year ago

@tomhennigan Sorry about the delay; dm-haiku 0.0.10 fixed my issue. When I had tried updating Haiku right before hitting the error, PyPI only had 0.0.9, which didn't include the fix.

Thanks!
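Since the resolution here was simply a newer dm-haiku release, a script could check the installed version before using Haiku with JAX 0.4.14. This is a hedged sketch: the 0.0.10 threshold comes from this thread, the helper names parse_release and haiku_is_new_enough are hypothetical, and the comparison only looks at the leading numeric release segment (packaging.version would handle full PEP 440 semantics more robustly).

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Per this thread, dm-haiku 0.0.10 worked with JAX 0.4.14 and 0.0.9 did not.
MIN_HAIKU = (0, 0, 10)


def parse_release(v: str) -> tuple:
    """Return the leading numeric release segment, e.g. '0.0.10' -> (0, 0, 10)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        return ()  # Unparseable version strings compare as "too old".
    return tuple(int(part) for part in m.group(0).split("."))


def haiku_is_new_enough(installed=None) -> bool:
    """Check a version string, or the installed dm-haiku if none is given."""
    if installed is None:
        try:
            installed = version("dm-haiku")
        except PackageNotFoundError:
            return False  # dm-haiku is not installed at all.
    return parse_release(installed) >= MIN_HAIKU


print(haiku_is_new_enough("0.0.9"))   # False
print(haiku_is_new_enough("0.0.10"))  # True
```

If the check fails, `pip install --upgrade "dm-haiku>=0.0.10"` would pull in the release that fixed the problem, once it is available on PyPI.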

tomhennigan commented 1 year ago

Thanks for letting us know 😄