google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

Fix issues related to new behavior of JAX DeviceArray.copy() #1734

Status: Open. Opened by copybara-service[bot] 2 years ago.



In https://github.com/google/jax/pull/10069, JAX changed the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than a NumPy array. Code that relied on .copy() to obtain a NumPy array must therefore convert explicitly; the preferred method is now np.asarray(device_array).
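A minimal sketch of the behavior change, assuming a recent JAX release in which DeviceArray has been superseded by jax.Array (the names differ from the JAX version this PR targeted, but the distinction is the same). The helper to_numpy is hypothetical and is not part of Trax's API; it only illustrates applying np.asarray across a pytree of weights, as training code often needs to.

```python
import numpy as np
import jax
import jax.numpy as jnp

# A JAX array on the default device (CPU, GPU, or TPU).
x = jnp.arange(4.0)

# .copy() stays on the JAX side: the result is a JAX array, not a NumPy array.
x_copy = x.copy()

# The preferred way to obtain a host-side NumPy array.
x_np = np.asarray(x)


def to_numpy(tree):
    """Hypothetical helper: convert every leaf of a pytree to NumPy."""
    return jax.tree_util.tree_map(np.asarray, tree)


weights = to_numpy({"w": jnp.ones(2), "b": jnp.zeros(2)})
print(type(x_copy), type(x_np), type(weights["w"]))
```

If an independent, writable NumPy buffer is required rather than a conversion, np.array(x) always makes a fresh copy.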