patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Disabling JAX import #178

Closed MilesCranmer closed 7 months ago

MilesCranmer commented 7 months ago

Hey @patrick-kidger! 👋

In the lines here:

https://github.com/patrick-kidger/jaxtyping/blob/8de8c0bb68c41dbd7d80a6e373eacae1229efe6a/jaxtyping/_array_types.py#L39-L48

I was wondering if we could add a mechanism for disabling the JAX import even when the import raises no errors. Right now jaxtyping automatically loads JAX whenever JAX is installed in the environment, even if I am not using JAX in my code. My group ran into this while debugging why importing jaxtyping would crash Python in PyTorch code. It turned out that we had a very old JAX installed, which was being loaded.

I think an environment variable might work for this, what do you think?

patrick-kidger commented 7 months ago

Hey Miles!

Sure, I think we can make this happen. In fact, I think we can do slightly better: jaxtyping can simply check whether JAX is installed, using importlib.util.find_spec, without actually importing it. We can then arrange for all of our JAX imports to happen dynamically, mostly by using a module-level __getattr__.

I've just pushed a new commit to the dev branch that does exactly this. Give it a try and let me know what you think.

MilesCranmer commented 7 months ago

Awesome, that seems to work! Thanks for the quick fix :)