[JAX] Fix incorrect type annotations.
An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype will be able to infer jax.Array accurately as a type in many more cases.
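The following is a hypothetical illustration of the kind of mismatch that more precise jax.numpy types expose. The function names are invented for this sketch and are not taken from the change itself. The "before" annotation promises a NumPy array, but jnp.sum actually returns a jax.Array. A checker that knows jnp.sum's return type would be expected to flag it, and the fix is to annotate with jax.Array.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Before (hypothetical): the return annotation claims np.ndarray, but
# jnp.sum returns a jax.Array. At runtime nothing fails; only a type
# checker with accurate jax.numpy types can notice the mismatch.
def total_before(x: np.ndarray) -> np.ndarray:
    return jnp.sum(x)


# After: annotate with jax.Array, the public type of JAX arrays.
def total_after(x: jax.Array) -> jax.Array:
    return jnp.sum(x)
```

Both functions behave identically at runtime. Only the annotation on the first one is wrong, which is why these errors surfaced only once pytype learned the real return types.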