Fixed type errors to unblock an internal type annotations refactoring in JAX
Some JAX internals used Any instead of Array in their type annotations.
https://github.com/google/jax/pull/17760 changed these to alias jax.Array,
which uncovered the type errors fixed here.