patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Update deprecated method usage #218

Closed: groszewn closed this 3 weeks ago

groszewn commented 4 weeks ago

JAX PR https://github.com/google/jax/pull/19930 deprecated jax.tree_map in favor of jax.tree.map, so this PR replaces the remaining uses of the deprecated name.
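A minimal sketch of what the migration looks like at a call site, assuming a JAX version recent enough to expose the jax.tree namespace. The params dictionary is a hypothetical example pytree, not code taken from jaxtyping. The call signature is unchanged; only the name moves.

```python
import jax
import jax.numpy as jnp

# Hypothetical example pytree: a dict of arrays.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Deprecated spelling (emits a DeprecationWarning on JAX versions that still have it):
# doubled = jax.tree_map(lambda x: 2 * x, params)

# Replacement: same arguments, new location under jax.tree.
doubled = jax.tree.map(lambda x: 2 * x, params)

# jax.tree_util.tree_map is the long-standing equivalent and is not deprecated.
also_doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
```

Both calls apply the function to every leaf and return a pytree with the same structure as the input.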

Additionally, this adds some configuration to skip tests that depend on sys.executable in environments where it is not set.
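The PR's actual test configuration is not shown here. The following is a minimal sketch of one common way to do this with pytest. It relies on the fact that Python sets sys.executable to an empty string or None when it cannot determine the interpreter path. The marker name requires_executable and the test body are hypothetical.

```python
import subprocess
import sys

import pytest

# Hypothetical reusable marker: skip when the interpreter path is unknown
# (sys.executable is "" or None in some embedded/sandboxed environments).
requires_executable = pytest.mark.skipif(
    not sys.executable, reason="sys.executable is not set in this environment"
)


@requires_executable
def test_runs_in_subprocess():
    # Launch a fresh interpreter, which needs a valid sys.executable.
    result = subprocess.run(
        [sys.executable, "-c", "print('ok')"],
        capture_output=True,
        text=True,
        check=True,
    )
    assert result.stdout.strip() == "ok"
```

Defining the marker once keeps the skip condition and its reason in one place, so every subprocess-based test can reuse it.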

patrick-kidger commented 3 weeks ago

LGTM! Thank you for the fix. Thought I'd finally snagged all those pesky jax.tree_maps but apparently not!