patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Runtime type checking via `typeguard` causes `TypeError` due to arrays having type `DeviceArray`. #33

Status: Open. jaymody opened this issue 2 years ago.

jaymody commented 2 years ago

I'm trying to use jaxtyping with runtime type checking via typeguard as described here. Here's my code:

import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped
from typeguard import typechecked as typechecker

@jaxtyped
@typechecker
def foo(
    x: Float[Array, "n"],
    y: Float[Array, "n"],
) -> Float[Array, "n"]:
    return x + y

print(foo(jnp.arange(10), jnp.arange(10)))

However when I run the above script, I get the following error:

Traceback (most recent call last):
  File "/Users/jay/playground/myscript.py", line 14, in <module>
    print(foo(jnp.arange(10), jnp.arange(10)))
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/jaxtyping/decorator.py", line 41, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 1032, in wrapper
    check_argument_types(memo)
  File "/Users/jay/playground/.venv/lib/python3.9/site-packages/typeguard/__init__.py", line 875, in check_argument_types
    raise TypeError(*exc.args) from None
TypeError: type of argument "x" must be jaxtyping.Float[ndarray, 'n']; got jaxlib.xla_extension.DeviceArray instead

Steps to reproduce my Python environment (note: I'm running this on an M1 MacBook Pro with macOS Monterey 12.2 (21D49)):

$ python -V
Python 3.9.10

$ python -m venv .venv

$ source .venv/bin/activate

$ python -m pip install --upgrade pip

$ python -m pip install "jax[cpu]==0.3.17" "jaxtyping==0.2.5"
jaymody commented 2 years ago

Ah, I realize now that it's because jnp.arange returns an integer array by default. If I change the call to print(foo(jnp.arange(10)*1.0, jnp.arange(10)*1.0)), I no longer get an error. Could the error message be more descriptive, or is this quirk documented somewhere? As it stands, the message is a bit misleading: it points at the array type (DeviceArray) rather than at the dtype, which is the actual problem.
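A minimal sketch of the same fix without the `* 1.0` trick, assuming JAX's default 32-bit mode (x64 disabled), in which `jnp.arange(10)` yields `int32`: request a floating dtype explicitly, or convert an existing array with `astype`.

```python
import jax.numpy as jnp

# Integer arguments make jnp.arange infer an integer dtype.
x_int = jnp.arange(10)
# Ask for a floating dtype up front instead of multiplying by 1.0 ...
x_float = jnp.arange(10, dtype=jnp.float32)
# ... or convert an array that already exists.
y_float = x_int.astype(jnp.float32)

print(x_int.dtype, x_float.dtype, y_float.dtype)  # int32 float32 float32
```

With inputs built this way, `foo(x_float, y_float)` from the original report should satisfy the `Float[Array, "n"]` annotations.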

patrick-kidger commented 2 years ago

Right; something similar came up in #6. It would indeed be great if the error message included more information, but it's the type checker (in this case typeguard) that raises the error, not jaxtyping. (All jaxtyping does is provide the types themselves.)

FWIW, my usual approach to debugging this is to rerun under the debugger, so that I can check which types were passed myself. This can be done with either of:

python -m pdb -c continue your_script.py
ipython your_script.py --pdb
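As a lighter-weight complement to the debugger, a small helper can summarise what was actually passed. `describe` below is hypothetical (it is not part of jaxtyping or typeguard); it only reads each argument's class name, `.shape` and `.dtype`, and can be called from an `except TypeError:` block or from the pdb prompt.

```python
import jax.numpy as jnp


def describe(**arrays):
    """Hypothetical helper: map each argument name to (class, shape, dtype)."""
    return {
        name: (
            type(a).__name__,
            getattr(a, "shape", None),
            str(getattr(a, "dtype", None)),
        )
        for name, a in arrays.items()
    }


x = jnp.arange(10)        # integer dtype
y = jnp.arange(10) * 1.0  # promoted to a floating dtype

for name, (cls, shape, dtype) in describe(x=x, y=y).items():
    print(f"{name}: {cls} shape={shape} dtype={dtype}")
```

The class name printed depends on the JAX version, but the dtype column makes an integer-vs-float mismatch visible immediately.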

(I'd like a better solution to this too.)

jaymody commented 2 years ago

Yeah, that's the workaround I'm using as well to check shapes and types when an error comes up. Maybe it's worth documenting this in API.md? I missed #6 in my search for a solution (which is on me, tbh), but it might be useful for the next person who will inevitably run into this without thoroughly checking the issues on GitHub.