patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Splicing / variadic symbolic expressions #265

Open martenlienen opened 2 days ago

martenlienen commented 2 days ago

Would it be possible to make the following code snippet work?

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

class A:
    def __init__(self, shape: tuple[int, ...]):
        self.shape = shape

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "... {self.shape}"]) -> Float[Tensor, "..."]:
        return x.flatten(start_dim=-len(self.shape)).sum(dim=-1)

a = A((3, 10, 5))
x = torch.randn((7, 3, 10, 5))  # trailing dims match self.shape
print(a.forward(x))

At the moment it does not work, as far as I can tell, because {self.shape} is matched against only a single dimension of x. Is there a way to evaluate the expression and splice the tuple's values into the annotation before the type gets matched against the dimensions? Maybe with something like a *{self.shape} syntax?
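For illustration, the proposed splicing could be simulated as a string preprocessing step. This is a hypothetical sketch of the desired *{...} semantics, not how jaxtyping actually expands annotations today; the function name and use of eval are purely illustrative:

```python
import re
from types import SimpleNamespace

def splice_dims(annotation: str, env: dict) -> str:
    """Expand {expr} placeholders in a dim string; a tuple value is
    spliced into one dim per element (sketch of the proposed *{...}
    behaviour -- NOT jaxtyping's actual expansion logic)."""
    def repl(match: re.Match) -> str:
        value = eval(match.group(1), {}, env)  # illustration only
        if isinstance(value, tuple):
            return " ".join(str(v) for v in value)
        return str(value)
    return re.sub(r"\{([^}]+)\}", repl, annotation)

obj = SimpleNamespace(shape=(3, 10, 5))
print(splice_dims("... {self.shape}", {"self": obj}))  # -> "... 3 10 5"
```

With that expansion, "... {self.shape}" would become "... 3 10 5", which jaxtyping could then match dimension by dimension as usual.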

patrick-kidger commented 2 days ago

Yup, this is a known issue. I don't have a nice way to fix this right now -- this is quite a complicated corner of jaxtyping! -- but I'd be happy to take a PR if someone feels like taking this on.

If need be, you can maybe do something like str(self.shape).replace(",", " ")[1:-1], but that's obviously pretty messy.
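For reference, that string trick and a join-based equivalent (the helper names here are my own) produce the space-separated dim strings one would splice into the annotation. Whether jaxtyping accepts the result as multiple fixed dimensions in any given context isn't guaranteed, so treat these as a stopgap:

```python
def shape_to_dims(shape: tuple[int, ...]) -> str:
    """Space-separated dim string from a shape tuple."""
    return " ".join(str(d) for d in shape)

# The str()-based trick from the comment above, for comparison:
def shape_to_dims_via_str(shape: tuple[int, ...]) -> str:
    return str(shape).replace(",", " ")[1:-1]

print(shape_to_dims((3, 10, 5)))          # -> "3 10 5"
print(shape_to_dims_via_str((3, 10, 5)))  # -> "3  10  5" (doubled spaces)
```

Note the str()-based version leaves doubled spaces, and for a 1-tuple like (3,) it yields "3 " with a trailing space, so the join-based form is the safer of the two.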