Closed: Min-Li closed this issue 2 years ago.
Hi objax team,
When I ran your cifar10_advanced.py, it raised an error caused by the deprecation of `jax.api`.
Specifically, the following lines trigger an error because `jax` no longer has an `api` attribute: https://github.com/google/objax/blob/d0aefeeb573fb366f2ee547f6869f2ca1b7ef284/objax/variable.py#L370-L373
The error can easily be fixed by replacing `jax.api.device_put_sharded` with `jax.device_put_sharded`.
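A minimal sketch of the corrected call, assuming a JAX release in which `jax.device_put_sharded` is available at the top level (newer releases may deprecate it in favor of `jax.device_put` with an explicit sharding). The helper name `replicate` is hypothetical and is not copied from objax's `variable.py`; it only illustrates placing one full copy of an array on each local device through the top-level API instead of `jax.api`.

```python
import jax
import jax.numpy as jnp


def replicate(x):
    """Hypothetical helper: put one full copy of `x` on every local device.

    Uses `jax.device_put_sharded` from the top-level `jax` namespace,
    not the removed `jax.api` module. The result stacks the copies
    along a new leading axis of length `len(jax.local_devices())`.
    """
    devices = jax.local_devices()
    # One shard per device; each shard is the whole array.
    return jax.device_put_sharded([x] * len(devices), devices)


x = jnp.arange(4.0)
y = replicate(x)
print(y.shape)
```

On a CPU-only machine `jax.local_devices()` typically returns a single device, so the result has a leading axis of length 1.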