patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Installing `jaxtyping` makes Pytest slow(er) #220

Open thejcannon opened 3 months ago

thejcannon commented 3 months ago

👋

On my machine:

$ time python -c 'import jaxtyping'
python -c 'import jaxtyping'  3.66s user 4.95s system 312% cpu 2.757 total

So, because jaxtyping registers a pytest entry point, every pytest run loads jaxtyping._pytest_plugin, which imports jaxtyping, which takes about 2.75s.

Would you consider putting the pytest-related code under a different tree?
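For anyone wanting to confirm which plugins pytest will auto-load in a given environment, here is a minimal sketch that lists the installed entry points in the `pytest11` group (the group pytest scans for plugins). The helper name `pytest_plugin_entry_points` is made up for illustration; only the standard-library `importlib.metadata` API is used.

```python
from importlib.metadata import entry_points


def pytest_plugin_entry_points():
    """Return sorted (name, target) pairs for every installed 'pytest11' entry point."""
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+ exposes a selectable interface.
        selected = eps.select(group="pytest11")
    else:
        # Python 3.8/3.9 return a plain dict keyed by group name.
        selected = eps.get("pytest11", [])
    return sorted((ep.name, ep.value) for ep in selected)


if __name__ == "__main__":
    # Each target is a module that pytest imports at startup.
    for name, target in pytest_plugin_entry_points():
        print(f"{name} -> {target}")
```

If a module under `jaxtyping` shows up in the output, pytest imports it (and therefore the `jaxtyping` package) on every run in that environment.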

patrick-kidger commented 3 months ago

Hey there! Hmm, I don't think I understand this one. Doing just python -c 'import jaxtyping' shouldn't load pytest at all. In fact with this not even jaxtyping._pytest_plugin is loaded -- this entry point is accessed only by pytest itself.

thejcannon commented 3 months ago

Right, the two things I was trying to show are:

So, any pytest invocation is now 2.75s slower.

patrick-kidger commented 3 months ago

Ah, hmm. Perhaps it'll be the way that jaxtyping checks if JAX/Equinox are available. I'm speculating here as jaxtyping itself is tiny.
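To illustrate why an availability check can matter for import time, here is a generic sketch of testing whether a package is importable without actually importing it. This is an assumption about the general technique, not a description of what jaxtyping does internally; `is_available` is a hypothetical helper.

```python
import importlib.util
import sys


def is_available(module: str) -> bool:
    """Report whether a top-level module could be imported, without importing it."""
    if module in sys.modules:
        # Already imported, so it is certainly available.
        return True
    # find_spec only locates a top-level module; it does not execute it.
    return importlib.util.find_spec(module) is not None


if __name__ == "__main__":
    for name in ("jax", "equinox"):
        print(f"{name}: available={is_available(name)}")
```

By contrast, a `try: import jax` style check pays the full import cost of the package whenever it is installed.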

If you can, try benchmarking what happens when {jax, equinox} x {are, aren't} available to be imported and see how that changes things? Maybe we can optimise that somehow.

(FWIW I frequently use jaxtyping + pytest and don't see this kind of multisecond lag, but maybe there's some kind of machine-to-machine variation there.)
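The suggested benchmark could be sketched roughly as follows: time `import jaxtyping` in a fresh interpreter, and use CPython's `-X importtime` flag to see which imports dominate. The function names `time_import` and `slowest_imports` are hypothetical. Covering the {jax, equinox} x {are, aren't} combinations would mean running this in separate virtual environments with and without those packages installed.

```python
import subprocess
import sys
import time


def time_import(module: str, repeats: int = 3) -> float:
    """Best-of-N wall-clock seconds for `python -c 'import <module>'` in a fresh interpreter."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        subprocess.run(
            [sys.executable, "-c", f"import {module}"],
            check=True,
            capture_output=True,
        )
        best = min(best, time.perf_counter() - start)
    return best


def slowest_imports(module: str, top: int = 10):
    """Return the `top` (cumulative_microseconds, name) rows from -X importtime."""
    proc = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", f"import {module}"],
        check=True,
        capture_output=True,
        text=True,
    )
    rows = []
    for line in proc.stderr.splitlines():
        # Rows look like: "import time:   self_us | cumulative_us | module.name"
        if not line.startswith("import time:"):
            continue
        parts = line[len("import time:"):].split("|")
        if len(parts) != 3:
            continue
        try:
            cumulative_us = int(parts[1].strip())
        except ValueError:
            continue  # skip the header row
        rows.append((cumulative_us, parts[2].strip()))
    return sorted(rows, reverse=True)[:top]


if __name__ == "__main__":
    print(f"import jaxtyping: {time_import('jaxtyping'):.2f}s")
    for cumulative_us, name in slowest_imports("jaxtyping"):
        print(f"{cumulative_us / 1e6:8.3f}s  {name}")
```

Note that the timing includes interpreter startup, so comparing against `time_import("sys")` gives a baseline to subtract.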