patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

How can I inspect the jaxtyping bindings? #158

[Open] smathis-absci opened this issue 9 months ago

smathis-absci commented 9 months ago

Hi and thank you for this fantastic work!

When debugging my code, which uses beartype and jaxtyping, I repeatedly ran into situations where I wanted to inspect jaxtyping's bindings (the sizes that named axes such as "batch" have been bound to) at a particular point in the code. However, I could not find anything in the documentation that exposes these bindings.

What would be the best way to inspect the jaxtyping bindings at a given point in the code, for example when stepping through it with a debugger?

Thank you for your help!
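To make the question concrete: during a checked call, jaxtyping binds each named axis in an annotation such as `Float[Array, "batch channels"]` to the size it first sees, and every later use of that name must agree. The snippet below is a toy, stdlib-only model of that idea. `ToyShapeMemo` and its `check` method are hypothetical names made up for illustration and do not reflect jaxtyping's actual implementation. Its `bindings` dict stands in for the state the question asks how to inspect.

```python
from typing import Dict, Tuple


class ToyShapeMemo:
    """Hypothetical stand-in for jaxtyping's per-call memo (not the real implementation)."""

    def __init__(self) -> None:
        # Maps a named axis (e.g. "batch") to the size it was first bound to.
        self.bindings: Dict[str, int] = {}

    def check(self, shape: Tuple[int, ...], spec: str) -> None:
        names = spec.split()
        if len(names) != len(shape):
            raise TypeError(f"expected {len(names)} axes, got shape {shape}")
        for name, size in zip(names, shape):
            # The first occurrence binds the name; later occurrences must match it.
            bound = self.bindings.setdefault(name, size)
            if bound != size:
                raise TypeError(f"axis {name!r} already bound to {bound}, got {size}")


memo = ToyShapeMemo()
memo.check((2, 3), "batch channels")
memo.check((3,), "channels")
print(memo.bindings)  # {'batch': 2, 'channels': 3}
```

Checking a shape of `(4,)` against `"channels"` afterwards would raise `TypeError`, because `channels` is already bound to 3.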

smathis-absci commented 9 months ago

Oh, I think I found a way:

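```python
# Private module: not a public API, so this may change between releases.
from jaxtyping._storage import get_shape_memo

# ... your code ...
print(get_shape_memo()[:3])  # the first three entries hold the currently bound axis sizes
# ... your code ...
```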
patrick-kidger commented 9 months ago

Haha, well sleuthed! I've just created #159 to offer this as a public API, and in a human-readable fashion.

I hope that helps!