patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

jax dependency error when jax is not installed #188

Closed: jsternabsci closed this issue 4 months ago

jsternabsci commented 4 months ago

jaxtyping 0.2.26 introduced a bug when using the @jaxtyped decorator without jax installed.

Error:

/opt/conda/envs/testenv-20743/lib/python3.11/site-packages/jaxtyping/_decorator.py:192: in jaxtyped
    if _tb_flag and importlib.util.find_spec("jax._src.traceback_util") is not None:
E   ModuleNotFoundError: No module named 'jax'

Breaking change: https://github.com/patrick-kidger/jaxtyping/compare/v0.2.25...v0.2.26#diff-f792a47fc41c0cf008332f62022e6136a8f6c6e0514c743eedc7df172e519ce5R192
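The traceback follows from how importlib.util.find_spec treats dotted names: for "jax._src.traceback_util" it first imports the parent packages (jax, then jax._src), so a missing jax raises ModuleNotFoundError instead of returning None. Below is a minimal sketch of a guarded check, assuming all that is wanted is a True/False answer; submodule_available is a hypothetical helper, not the code used in jaxtyping's fix.

```python
import importlib.util


def submodule_available(name: str) -> bool:
    """Return True if the (possibly dotted) module `name` can be located.

    For a dotted name, find_spec imports the parent packages first, so it
    raises ModuleNotFoundError when a parent is missing (or another
    ImportError when a parent fails to import) instead of returning None.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return False


# Unguarded: raises ModuleNotFoundError because the parent package is missing.
try:
    importlib.util.find_spec("surely_missing_pkg_12345.sub")
except ModuleNotFoundError as exc:
    print("unguarded:", exc)

# Guarded: returns False instead of raising.
print(submodule_available("surely_missing_pkg_12345.sub"))  # False
print(submodule_available("json.decoder"))                   # True
```

Catching ImportError rather than only ModuleNotFoundError also covers a parent package that is present but fails partway through its own import.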

patrick-kidger commented 4 months ago

Thanks for the report! This should be fixed in #188, and I've just done a new release to include this.

ar0ck commented 4 months ago

My environment has jax but not jaxlib (don't ask me why :thinking:).

This means that importlib.util.find_spec("jax") succeeds, but importlib.util.find_spec("jax._src.traceback_util") fails with ModuleNotFoundError: jax requires jaxlib to be installed.

Thoughts on this? Should jaxtyping verify that jax not only exists, but is also importable? Or is it basically just user error to have jax without jaxlib?
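One way to check that a package is importable, rather than merely present, is to actually import it. The sketch below assumes that trade-off is acceptable. Eagerly importing jax is comparatively expensive, which may be one reason a find_spec-style lookup is used instead. is_importable and the throwaway package fake_jax are made-up names; fake_jax simulates the jax-without-jaxlib situation described above, where the package is found on disk but importing it fails.

```python
import importlib
import importlib.util
import sys
import tempfile
from pathlib import Path


def is_importable(name: str) -> bool:
    """Return True only if `import name` actually succeeds."""
    try:
        importlib.import_module(name)
    except ImportError:
        # Covers both "not installed" and "installed but fails to import".
        return False
    return True


# Simulate a package that exists on disk but raises on import, much like
# jax without jaxlib. `fake_jax` is a throwaway, hypothetical package name.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "fake_jax"
    pkg.mkdir()
    (pkg / "__init__.py").write_text(
        'raise ModuleNotFoundError("fake_jax requires fake_jaxlib to be installed.")\n'
    )
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()  # make sure the new directory is seen
    try:
        found = importlib.util.find_spec("fake_jax") is not None  # located on disk
        importable = is_importable("fake_jax")  # but importing it fails
    finally:
        sys.path.remove(tmp)

print(found, importable)  # True False
```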

patrick-kidger commented 4 months ago

Aha! That's unfortunate. I think we might as well add a check for jaxlib as well. I'd be happy to take a PR on that?
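The extra check suggested here can be done without importing anything by looking up the top-level jaxlib package alongside jax. For a top-level name, find_spec returns None rather than raising when the package is absent. A minimal sketch, assuming that the mere presence of both packages is a good enough signal (it would not catch, say, incompatible versions); all_installed is a hypothetical helper and not necessarily what the eventual PR does.

```python
import importlib.util


def all_installed(*names: str) -> bool:
    """Return True if every package in `names` can be located.

    Intended for top-level names: for those, find_spec does not import
    anything and returns None for a missing package instead of raising.
    """
    return all(importlib.util.find_spec(name) is not None for name in names)


# Hypothetical usage: only probe jax internals when jax *and* jaxlib are present.
if all_installed("jax", "jaxlib"):
    print("jax and jaxlib found")
else:
    print("jax and/or jaxlib missing; skipping jax-specific setup")
```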

ar0ck commented 4 months ago

Will do!