patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Symbolic expressions example doesn't run #181

Closed: kvablack closed this issue 4 months ago

kvablack commented 4 months ago

Hi, I'm trying to run the basic symbolic expressions example and I'm getting an error.

import jax
from jaxtyping import Float, Array
from typeguard import typechecked

@typechecked
def full(size: int, fill: float) -> Float[Array, "{size}"]:
    return jax.numpy.full((size,), fill)

full(10, 1.0)

# AnnotationError: Cannot process symbolic axis '{size}' as some axis names have not been processed. In practice you should usually only use symbolic axes in annotations for return types, referring only to axes annotated for arguments.

jaxtyping 0.2.25, typeguard 2.13.3

patrick-kidger commented 4 months ago

You need a @jaxtyped wrapper as well. Typically this is done via the single decorator @jaxtyped(typechecker=typechecked).

kvablack commented 4 months ago

My bad, thanks!