In `_which_np`, we check whether the variable is a `jnp.DeviceArray`; however, this check does not match `ShardedDeviceArray` or `GlobalDeviceArray`.
In JAX 0.4, these are unified into a single array type, `jax.Array`. We could either add support for the array types that are not currently handled, or update to JAX 0.4 and check against `jax.Array`.