AndPotap opened 2 weeks ago
The latest JAX version, 0.4.30, changes how a user queries the device of an array: we now have to use `array.devices()` instead of `array.device()`. Moreover, `array.devices()` returns a Python set rather than a single device, so I convert it to a list and pass on its first element.
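A minimal sketch of the workaround described above, assuming a single-device array. `first_device` is a hypothetical helper name used only for illustration. It is not part of JAX or of this PR.

```python
import jax
import jax.numpy as jnp


def first_device(array: jax.Array):
    """Hypothetical helper: return one device that holds `array`.

    `array.devices()` returns a set of devices, so we convert it to a
    list and take its first element.
    """
    return list(array.devices())[0]


x = jnp.ones(3)
dev = first_device(x)

# Pass the device on, e.g. to place another array on the same device as `x`.
y = jax.device_put(jnp.zeros(3), dev)
print(dev, y.devices())
```

For a single-device array the set has exactly one member, so "first" is unambiguous. For an array sharded across several devices, Python sets are unordered, so which device comes first is arbitrary. `next(iter(array.devices()))` gives the same result without building an intermediate list.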