patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Leading integer and ellipses in pytree raises error #227

Closed: smorad closed this issue 3 months ago

smorad commented 3 months ago

I am trying to typecheck the leading dimension of the arrays in a pytree, but I get an error when I combine a leading integer with "..." in the annotation string. Is this expected?

import jax
from jaxtyping import PyTree, Array

pytree = [
    jax.numpy.zeros((1, 10)),
    (jax.numpy.zeros((1, 3, 4)), jax.numpy.array((1,)))
]

RecurrentState = PyTree[Array, "Time ..."] # This is fine
SingleRecurrentState = PyTree[Array, "1 ..."] # This raises an error

The resulting error is:

Traceback (most recent call last):
  File "/Users/smorad/code/memoryx/test.py", line 10, in <module>
    SingleRecurrentState = PyTree[Array, "1 ..."]
                           ~~~~~~^^^^^^^^^^^^^^^^
  File "/Users/smorad/miniforge3/envs/memoryx/lib/python3.12/site-packages/jaxtyping/_pytree_type.py", line 218, in __getitem__
    raise ValueError(
ValueError: The string `struct` in `jaxtyping.PyTree[leaftype, struct]` must be be a whitespace-separated sequence of identifiers, e.g. `jaxtyping.PyTree[leaftype, 'T']` or `jaxtyping.PyTree[leaftype, 'foo bar']`.
(Here, 'identifier' is used in the same sense as in regular Python, i.e. a valid variable name.)
Got piece '1' in overall structure '1 ...'.

patrick-kidger commented 3 months ago

You want PyTree[Shaped[Array, "1 ..."]]. The second argument to PyTree encodes the structure of the PyTree and has nothing to do with array shapes. :)

smorad commented 3 months ago

Thanks for the quick response!