MilesCranmer closed this issue 7 months ago.
Hey Miles!
Sure, I think we can make this happen. In fact, I think we can do slightly better: we can set up jaxtyping to simply check whether JAX is installed, without actually importing it (using `importlib.util.find_spec`). Then we arrange for all of our JAX imports to happen dynamically (mostly by using a module-level `__getattr__`).
I've just pushed a new commit to the `dev` branch that does exactly this. Give it a try and let me know what you think!
Awesome, that seems to work! Thanks for the quick fix :)
Hey @patrick-kidger! 👋
In the lines here:
https://github.com/patrick-kidger/jaxtyping/blob/8de8c0bb68c41dbd7d80a6e373eacae1229efe6a/jaxtyping/_array_types.py#L39-L48
I was wondering if we could add a mechanism for disabling the jax import even when it succeeds without errors? Right now jaxtyping automatically loads jax if jax is installed in the environment, even if I am not using jax in my code. My group ran into this while debugging why importing jaxtyping would crash Python in PyTorch code: it turned out that we had a very old jax installed, and it was being loaded.
I think an environment variable might work for this, what do you think?
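A rough sketch of how such an opt-out could combine with the `find_spec` check, assuming a hypothetical variable named `JAXTYPING_DISABLE_JAX` (not an existing jaxtyping option) and a hypothetical helper `jax_enabled`:

```python
import importlib.util
import os

# Hypothetical environment variable name; not an existing jaxtyping setting.
_DISABLE_VAR = "JAXTYPING_DISABLE_JAX"


def jax_enabled(environ=None):
    """Return True if JAX support should be loaded (hypothetical helper).

    JAX is used only if it is installed *and* the user has not opted out
    via the environment variable above.
    """
    if environ is None:
        environ = os.environ
    flag = environ.get(_DISABLE_VAR, "").strip().lower()
    if flag in {"1", "true", "yes", "on"}:
        return False
    # find_spec locates the package without importing it, so an old or broken
    # JAX installation is never executed just by performing this check.
    return importlib.util.find_spec("jax") is not None
```

Usage would then look like `JAXTYPING_DISABLE_JAX=1 python train.py`. Reading the flag before any JAX import is what matters: once a broken jax has been imported, it is too late to opt out.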