google-deepmind / rlax

https://rlax.readthedocs.io
Apache License 2.0

[JAX] Fix incorrect type annotations. #129

Closed copybara-service[bot] closed 1 year ago

copybara-service[bot] commented 1 year ago

An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer jax.Array accurately as a type in many more cases.
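As a rough illustration of the kind of error this describes, here is a minimal sketch. The function names are hypothetical and are not taken from rlax or from this PR's diff. It assumes a helper whose return annotation claims a Python `float` while the body actually returns the result of a `jax.numpy` call, which is a `jax.Array`.

```python
import jax
import jax.numpy as jnp


# Hypothetical "before": the annotation says `float`, but jnp.mean returns a
# jax.Array. Once a checker can infer jax.Array for jax.numpy functions, this
# mismatch between the annotation and the returned value can be reported.
def squared_error_before(predictions: jax.Array, targets: jax.Array) -> float:
    return jnp.mean(jnp.square(predictions - targets))


# Hypothetical "after": the annotation matches what is actually returned.
def squared_error_after(predictions: jax.Array, targets: jax.Array) -> jax.Array:
    return jnp.mean(jnp.square(predictions - targets))


predictions = jnp.array([1.0, 2.0])
targets = jnp.array([1.0, 4.0])

before = squared_error_before(predictions, targets)
after = squared_error_after(predictions, targets)

# At runtime both functions behave identically; only the annotation differs.
print(type(before).__name__, isinstance(before, float))
print(isinstance(after, jax.Array))
```

Both versions run without error, which is why such mistakes tend to go unnoticed until the type checker's view of `jax.numpy` becomes more precise.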