copybara-service[bot] closed this 2 years ago.
[NumPy] Fix test failure in distrax under NumPy 1.23, attempt 2.
The semantics of multidimensional indexing with non-tuple indices changed in NumPy 1.23: NumPy now accepts only tuples as multidimensional indices.
In previous versions of NumPy, it appears that jnp.array([-1, 0]) was being interpreted as a tuple; we must now make that conversion explicit.
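A minimal sketch of the difference follows. It uses plain NumPy rather than `jax.numpy`, and the 3x3 array `x` and the nested list `ind` are hypothetical illustration values, not the actual distrax code. Wrapping an index in `tuple(...)` spreads its entries across dimensions (one entry per axis), while an array index selects only along the first axis. Since NumPy 1.23, a non-tuple sequence index is treated like the array form, so code that relied on the old tuple interpretation has to call `tuple(...)` explicitly.

```python
import numpy as np

# Hypothetical 3x3 array, used only for illustration.
x = np.arange(9).reshape(3, 3)

# A 1-D integer index array, as in the commit message.
idx = np.array([-1, 0])

# Array index: selects rows -1 and 0 along the first axis.
rows = x[idx]

# Explicit tuple: one index per dimension, i.e. x[-1, 0].
elem = x[tuple(idx)]

# Same idea with a nested list: tuple(ind) means x[[0, 1], [0, 1]],
# picking the elements (0, 0) and (1, 1).
ind = [[0, 1], [0, 1]]
diag = x[tuple(ind)]
stacked = x[np.array(ind)]  # a (2, 2) array of row indices -> shape (2, 2, 3)

print(rows.tolist())  # [[6, 7, 8], [0, 1, 2]]
print(int(elem))      # 6
print(diag.tolist())  # [0, 4]
print(stacked.shape)  # (2, 2, 3)
```

Writing `x[tuple(idx)]` gives the same result on every NumPy version, which is why making the tuple explicit is the safer fix.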