google / objax

Apache License 2.0

Error due to the deprecation of jax.api #228

Closed Min-Li closed 2 years ago

Min-Li commented 2 years ago

Hi objax team,

When I run your cifar10_advanced.py, it raises an error due to the deprecation of jax.api.

Specifically, the following lines trigger errors because jax no longer has an api attribute: https://github.com/google/objax/blob/d0aefeeb573fb366f2ee547f6869f2ca1b7ef284/objax/variable.py#L370-L373

The error can be easily fixed by replacing jax.api.device_put_sharded with jax.device_put_sharded.