Farama-Foundation / Jumpy

On-the-fly conversions between Jax and NumPy tensors
Apache License 2.0

Update to use `jax.Array` #24

Open · pseudo-rnd-thoughts opened this issue 1 year ago

pseudo-rnd-thoughts commented 1 year ago

In `_which_np`, we check whether the variable is a `jnp.DeviceArray`. However, this check does not match `ShardedDeviceArray` or `GlobalDeviceArray`. In Jax 0.4, these types are replaced by `jax.Array`, a single unified array type. We could either add support for the array types that are not currently handled, or update Jumpy to require Jax 0.4.