patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Fix _check_shape str formatting for variadics #186

Closed: asford closed this pull request 4 months ago

asford commented 4 months ago

__instancecheck_str__ raises a TypeError while formatting its error message in one specific case of variadic shape mismatch: when the array has fewer dimensions than the minimum the type hint requires. As a result, runtime type checkers raise an exception instead of reporting a failed check.
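The traceback suggests why a formatting bug escapes as an exception. The metaclass's __instancecheck__ returns cls.__instancecheck_str__(obj) == "", so an empty string means "passes", and any exception raised while building the failure message propagates straight out of isinstance. The sketch below imitates that pattern in plain Python. _StrCheckMeta, AtLeast2D and BuggyAtLeast2D are hypothetical names, not jaxtyping's classes, and a tuple of ints stands in for an array's shape.

```python
class _StrCheckMeta(type):
    """Hypothetical metaclass: an empty string from __instancecheck_str__ means "passes"."""

    def __instancecheck__(cls, obj):
        # Any exception raised inside __instancecheck_str__ propagates out of isinstance().
        return cls.__instancecheck_str__(obj) == ""


class AtLeast2D(metaclass=_StrCheckMeta):
    # Hypothetical stand-in for a hint like "*p b c": one variadic plus two fixed dims.
    dims = ("*p", "b", "c")

    @classmethod
    def __instancecheck_str__(cls, shape):
        # `shape` is a plain tuple of ints here, standing in for obj.shape.
        if len(shape) < len(cls.dims) - 1:
            return (
                f"this shape has {len(shape)} dimensions, fewer than the "
                f"minimum of {len(cls.dims) - 1}"
            )
        return ""


class BuggyAtLeast2D(AtLeast2D):
    @classmethod
    def __instancecheck_str__(cls, shape):
        if len(shape) < len(cls.dims) - 1:
            # Same misplaced parenthesis as in the traceback: tuple - int raises TypeError.
            return f"minimum expected is {len(cls.dims - 1)}"
        return ""
```

Here isinstance((3, 10), BuggyAtLeast2D) still returns True, because the message is only built on failure, while isinstance((10,), BuggyAtLeast2D) raises TypeError instead of returning False, mirroring ta versus ta[0] in the repro.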

This is reproduced via:

import numpy as np

from jaxtyping import Shaped

ta = np.arange(3 * 10).reshape((3, 10))

# returns True: the 2-D array matches "*p b c" with *p matching zero axes
isinstance(ta, Shaped[np.ndarray, "*p b c"])

# raises TypeError instead of returning False (ta[0] is 1-D)
isinstance(ta[0], Shaped[np.ndarray, "*p b c"])

which fails with the following error:

Traceback (most recent call last):
  File "/workspaces/jaxtyping/repro.py", line 11, in <module>
    isinstance(ta[0], Shaped[np.ndarray, "*p b c"])
  File "/workspaces/jaxtyping/jaxtyping/_array_types.py", line 156, in __instancecheck__
    return cls.__instancecheck_str__(obj) == ""
  File "/workspaces/jaxtyping/jaxtyping/_array_types.py", line 205, in __instancecheck_str__
    check = cls._check_shape(obj, single_memo, variadic_memo, arg_memo)
  File "/workspaces/jaxtyping/jaxtyping/_array_types.py", line 232, in _check_shape
    return f"this array has {obj.ndim} dimensions, which is fewer than {len(cls.dims - 1)} that is the minimum expected by the type hint"  # noqa: E501
TypeError: unsupported operand type(s) for -: 'tuple' and 'int'
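The last frame pinpoints the bug: the parenthesis in len(cls.dims - 1) is misplaced, so Python evaluates tuple - int before calling len. Assuming the hint contains exactly one variadic (which may match zero axes), the intended minimum is most likely len(cls.dims) - 1. A minimal illustration, with a hypothetical dims tuple standing in for cls.dims:

```python
# Hypothetical stand-in for cls.dims; the error message shows it is a tuple.
dims = ("*p", "b", "c")

# Buggy expression from the traceback: the subtraction happens inside len().
try:
    len(dims - 1)
    error = None
except TypeError as err:
    error = str(err)

# Corrected expression: count the dims first, then drop the single variadic,
# which may match zero axes. Only "b" and "c" are required.
min_ndim = len(dims) - 1

print(error)     # unsupported operand type(s) for -: 'tuple' and 'int'
print(min_ndim)  # 2
```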
patrick-kidger commented 4 months ago

LGTM! Thank you for the fix. I'll do a quick release with this one in.