I am trying to typecheck a leading dimension in a pytree, but I get an error when I combine a leading integer with `...`. Is this expected?
import jax
from jaxtyping import PyTree, Array
pytree = [
    jax.numpy.zeros((1, 10)),
    (jax.numpy.zeros((1, 3, 4)), jax.numpy.array((1,)))
]
RecurrentState = PyTree[Array, "Time ..."] # This is fine
SingleRecurrentState = PyTree[Array, "1 ..."] # This raises an error
The resulting error is:
Traceback (most recent call last):
File "/Users/smorad/code/memoryx/test.py", line 10, in <module>
SingleRecurrentState = PyTree[Array, "1 ..."]
~~~~~~^^^^^^^^^^^^^^^^
File "/Users/smorad/miniforge3/envs/memoryx/lib/python3.12/site-packages/jaxtyping/_pytree_type.py", line 218, in __getitem__
raise ValueError(
ValueError: The string `struct` in `jaxtyping.PyTree[leaftype, struct]` must be be a whitespace-separated sequence of identifiers, e.g. `jaxtyping.PyTree[leaftype, 'T']` or `jaxtyping.PyTree[leaftype, 'foo bar']`.
(Here, 'identifier' is used in the same sense as in regular Python, i.e. a valid variable name.)
Got piece '1' in overall structure '1 ...'.
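To make the message concrete, here is a minimal sketch of the check it seems to describe. This is not jaxtyping's actual code: `check_struct` is a hypothetical name, and I am assuming that a leading or trailing `...` is skipped before the identifier check, which would explain why "Time ..." is accepted while "1 ..." is not.

```python
def check_struct(struct: str) -> list[str]:
    """Hypothetical stand-in for the validation described by the error message."""
    pieces = struct.split()
    # Assumption: a leading or trailing "..." is handled specially and skipped.
    if pieces and pieces[0] == "...":
        pieces = pieces[1:]
    if pieces and pieces[-1] == "...":
        pieces = pieces[:-1]
    for piece in pieces:
        # Every remaining piece must be a valid Python variable name.
        if not piece.isidentifier():
            raise ValueError(
                f"Got piece {piece!r} in overall structure {struct!r}."
            )
    return pieces


print(check_struct("Time ..."))  # "Time" is a valid identifier

try:
    check_struct("1 ...")  # "1" is not a valid identifier
except ValueError as err:
    print(err)
```

Under these assumptions, any integer in the structure string would be rejected, since a Python identifier cannot start with a digit.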