jaymody opened this issue 2 years ago
Ah, so I'm realizing it's because jnp.arange returns an integer array by default. If I change the call to print(foo(jnp.arange(10)*1.0, jnp.arange(10)*1.0)), I no longer get an error. Could the error message be made more descriptive, or is this quirk documented somewhere? As it stands, the error message is a bit misleading.
Right; something similar came up in #6. It would indeed be great if the error message included more information, but it's the typechecker (typeguard, in this case) that raises the error, not jaxtyping. (All jaxtyping does is provide the types themselves.)
FWIW, my usual approach to debugging this is to rerun under the debugger, so that I can check for myself what types were passed. This can be done with either of:
python -m pdb -c continue your_script.py
ipython --pdb your_script.py
(I'd like a better solution to this too.)
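As a non-interactive complement to the debugger, a small wrapper can print the shape and dtype of every argument whenever the call fails, then re-raise the original error unchanged. This is a sketch: report_arg_types and describe_value are hypothetical names, and foo merely stands in for a function checked by typeguard/jaxtyping. In real use the wrapper would go outermost, above the type-checking decorator:

```python
import functools
import sys

import jax.numpy as jnp


def describe_value(value):
    """Summarise an argument: shape and dtype for arrays, else the type name."""
    if hasattr(value, "shape") and hasattr(value, "dtype"):
        return f"{type(value).__name__}(shape={tuple(value.shape)}, dtype={value.dtype})"
    return type(value).__name__


def report_arg_types(fn):
    """On any exception from fn, print each argument's shape/dtype, then re-raise.

    Hypothetical helper for illustration; not part of jaxtyping or typeguard.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            parts = [f"arg {i}: {describe_value(a)}" for i, a in enumerate(args)]
            parts += [f"{k}: {describe_value(v)}" for k, v in kwargs.items()]
            print(f"{fn.__name__} received: " + "; ".join(parts), file=sys.stderr)
            raise  # re-raise the original error, keeping its type and traceback

    return wrapper


@report_arg_types
def foo(x, y):
    # Hypothetical stand-in for a type-checked function: reject non-float input.
    if not (jnp.issubdtype(x.dtype, jnp.floating) and jnp.issubdtype(y.dtype, jnp.floating)):
        raise TypeError("x and y must be floating-point arrays")
    return x + y


try:
    foo(jnp.arange(10), jnp.arange(10))  # integer inputs: report printed to stderr
except TypeError as err:
    print("caught:", err)
```

Re-raising with a bare raise keeps whatever exception class the typechecker uses, which matters because different typeguard versions raise different error types.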
Yeah, that's the workaround I'm using as well to check the shapes and types when an error comes up. Maybe it's worth documenting this in API.md? I missed #6 in my search for a solution (which is on me, to be honest), but it might be useful for the next person who inevitably comes across this without thoroughly checking the issues on GitHub.
I'm trying to use jaxtyping with runtime type checking via typeguard, as described here. Here's my code:
However, when I run the above script, I get the following error:
Steps to reproduce my Python environment (note: I'm running this on an M1 MacBook Pro with macOS Monterey 12.2 (21D49)):
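Alongside such steps, a snippet like the following can capture the Python, OS and package versions in one go using only the standard library. It is a sketch; the package list is an assumption about what is relevant to this report:

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Packages assumed relevant to this report; adjust the list as needed.
PACKAGES = ["jax", "jaxlib", "jaxtyping", "typeguard"]


def environment_report(packages=PACKAGES):
    """Return Python/OS details plus one 'name: version' line per package."""
    lines = [
        f"python: {platform.python_version()}",
        f"platform: {platform.platform()}",
        f"machine: {platform.machine()}",  # e.g. distinguishes arm64 from x86_64
    ]
    for name in packages:
        try:
            lines.append(f"{name}: {version(name)}")
        except PackageNotFoundError:
            lines.append(f"{name}: not installed")
    return "\n".join(lines)


print(environment_report())
```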