google-deepmind / distrax

Apache License 2.0

[JAX] Fix incorrect type annotations. #248

Closed: copybara-service[bot] closed this pull request 1 year ago.

copybara-service[bot] commented 1 year ago

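An upcoming change to JAX will give pytype more accurate types for functions in the jax.numpy module. This exposes a number of type errors in downstream users of JAX. In particular, pytype will be able to infer jax.Array as a type in many more cases, so annotations that previously went unchecked (for example, declaring np.ndarray where a jax.numpy function actually returns a jax.Array) will now be flagged.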
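A minimal sketch of the kind of annotation mismatch this change surfaces, assuming a recent JAX (0.4.1 or later, where jax.Array is public). The helpers scaled_exp_before and scaled_exp are hypothetical illustrations, not code from distrax or the actual diff of this PR. Python does not enforce annotations at runtime, so the "before" version runs fine; only a type checker that knows jax.numpy returns jax.Array would report it.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical "before": annotated as np.ndarray, but jnp.exp returns a jax.Array.
# Runs fine at runtime; a checker aware of jnp's return types would flag the mismatch.
def scaled_exp_before(x: np.ndarray, scale: float = 1.0) -> np.ndarray:
    return scale * jnp.exp(x)


# Hypothetical "after": annotations match what jax.numpy actually returns.
def scaled_exp(x: jax.Array, scale: float = 1.0) -> jax.Array:
    """Return scale * exp(x) as a jax.Array."""
    return scale * jnp.exp(x)


before = scaled_exp_before(np.zeros(3), scale=2.0)
y = scaled_exp(jnp.zeros(3), scale=2.0)

print(isinstance(before, np.ndarray))  # False: the old annotation was wrong
print(isinstance(y, jax.Array))        # True
```

The fix is purely in the annotations; runtime behavior of both helpers is identical.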