patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Annotations for tensors with dynamic dimensions #225

Closed: martenlienen closed this issue 1 week ago

martenlienen commented 2 weeks ago

Hi, what would be the idiomatic way to describe the return shape of, for example, torch.zeros in jaxtyping?

from jaxtyping import Float
from torch import Tensor
def zeros(size: tuple[int, ...]) -> Float[Tensor, "what do I put here?"]:
    pass

How can I encode that the output shape is determined by the size tuple, i.e. that the output has len(size) dimensions and each dimension's length is the corresponding element of size?

patrick-kidger commented 1 week ago

Hey there! Unfortunately this isn't supported -- it'd require a fairly tricky rewrite of some internals: https://github.com/patrick-kidger/jaxtyping/pull/140#issuecomment-1804886283

martenlienen commented 1 week ago

Thanks for clearing that up!