google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

Fix issues related to new behavior of JAX DeviceArray.copy() #1735

Closed: copybara-service[bot] closed this pull request 2 years ago.

copybara-service[bot] commented 2 years ago


In https://github.com/google/jax/pull/10069, JAX changed the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than a NumPy array. To convert a DeviceArray to a NumPy array, the preferred method is to call np.asarray(device_array) explicitly.
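The sketch below contrasts the two behaviors. It assumes a JAX release that includes this change. Newer releases call the type jax.Array rather than DeviceArray, but .copy() behaves the same way. The variable names are illustrative and are not taken from the Trax code touched by this PR.

```python
import numpy as np
import jax.numpy as jnp

# A JAX array (a DeviceArray in JAX releases of that era).
x = jnp.arange(4.0)

# After google/jax#10069, .copy() stays on the JAX side:
# the result is another JAX array, not a numpy.ndarray.
x_copy = x.copy()

# Explicit conversion to NumPy, as the PR recommends.
x_np = np.asarray(x)
print(type(x_np))  # <class 'numpy.ndarray'>

# np.array always copies, so this NumPy array is writable and
# independent of the original JAX buffer.
x_writable = np.array(x)
x_writable[0] = 10.0
```

np.asarray may skip the copy when the data already lives on the host. When code needs a mutable, independent NumPy buffer, np.array(x) is the safer choice.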