[JAX] Replace uses of isinstance(x, jax.numpy.ndarray) with
isinstance(x, (numpy.ndarray, jax.numpy.ndarray)) where
they were leading to test failures.

An upcoming change to JAX will make isinstance(x, jax.numpy.ndarray)
return true if and only if x is an instance of a JAX array type.
Currently, isinstance(x, jax.numpy.ndarray) also returns true for
classic NumPy numpy.ndarray objects. After the upcoming change, it
will return false for numpy.ndarray objects.

This change updates users of JAX that depended on the current
behavior of the isinstance check so that they explicitly check for
numpy.ndarray instances as well.
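
As a minimal sketch of the updated pattern (the helper name is_array is
hypothetical and used only for illustration, not taken from any of the
updated call sites), a check that should accept both array kinds could
look like this:

```python
import numpy as np
import jax.numpy as jnp


def is_array(x):
    """Hypothetical helper: accept both NumPy and JAX arrays."""
    # Listing numpy.ndarray explicitly keeps this check true for NumPy
    # arrays even once jax.numpy.ndarray stops matching them.
    return isinstance(x, (np.ndarray, jnp.ndarray))


print(is_array(np.zeros(3)))   # NumPy array
print(is_array(jnp.zeros(3)))  # JAX array
print(is_array([0.0, 0.0]))    # plain Python list, not an array
```

Because numpy.ndarray is named explicitly in the tuple, the result for
NumPy arrays does not depend on which behavior of jax.numpy.ndarray is
in effect.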

These changes should have no effect on using jax.numpy.ndarray as
a type annotation. Such an annotation does little and never has,
although that may change in the future. This change is strictly
about what jax.numpy.ndarray means to runtime isinstance checks.