patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Support for NestedTensors #185

Open jaanli opened 7 months ago

jaanli commented 7 months ago

Currently, NestedTensors raise errors (for examples of NestedTensors, see https://pytorch.org/tutorials/prototype/nestedtensor.html and https://pytorch.org/docs/stable/nested.html):

File <@beartype(models.transformer.Attention.forward) at 0x7f7ac394d430>:23, in forward(__beartype_object_221882800, __beartype_get_violation, __beartype_conf, __beartype_object_140169208358272, __beartype_object_140169211732352, __beartype_func, *args, **kwargs)
...
RuntimeError: Internal error: NestedTensorImpl doesn't support sizes. Please file an issue.

Just an idea, but it could be helpful for more folks! :) This library has already saved me from many bugs, btw 🙏

patrick-kidger commented 7 months ago

Thank you for your kind words!

Honestly, I'm not sure how nested tensors would work with something like jaxtyping. They don't have rectangular shapes, after all. Do you have a suggestion?

jaanli commented 7 months ago

Currently the library supports a wildcard operator (e.g. Float[Tensor, "*batch seqlen"]) to denote a variable number of dimensions.

NestedTensors have "ragged" dimension-like analogs.

I don't think support is necessary at this point, but if PyTorch's NestedTensors end up popular, perhaps a similar marker such as ~ could be used to indicate ragged dimensions?
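
To make the idea concrete, here is a rough pure-Python sketch of what a ragged marker could mean. Everything in it is hypothetical: `matches`, `lengths_per_level`, and the tuple-of-tokens spec are made up for illustration, it works on nested lists rather than tensors, and `~` is not existing jaxtyping syntax.

```python
def lengths_per_level(value):
    """For each nesting depth of a nested list, collect the set of lengths seen."""
    levels = []
    frontier = [value]
    while frontier and all(isinstance(v, list) for v in frontier):
        levels.append({len(v) for v in frontier})
        frontier = [item for v in frontier for item in v]
    return levels


def matches(value, spec):
    """Check a nested list against a tuple of per-dimension tokens.

    Tokens (all hypothetical):
      int  -> every list at this depth must have exactly this length
      "_"  -> length is not checked, but must be uniform at this depth
      "~"  -> ragged: lengths at this depth may differ between siblings
    """
    levels = lengths_per_level(value)
    if len(levels) != len(spec):
        return False
    for seen, token in zip(levels, spec):
        if token == "~":
            continue  # ragged dimension: any mix of lengths is accepted
        if len(seen) != 1:
            return False  # a non-ragged dimension must be uniform
        if isinstance(token, int) and seen != {token}:
            return False
    return True


ragged_batch = [[1.0, 2.0, 3.0], [4.0]]   # two sequences of different length
dense_batch = [[1.0, 2.0], [3.0, 4.0]]    # ordinary rectangular data

print(matches(ragged_batch, (2, "~")))    # ragged marker accepts it
print(matches(ragged_batch, (2, "_")))    # a uniform dimension does not
print(matches(dense_batch, ("_", 2)))     # rectangular data, first dim unchecked
```

The point of the sketch is only that a ragged marker would have to relax the "all siblings share one length" rule for a single dimension while leaving the others checked as usual.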

patrick-kidger commented 7 months ago

Something like the existing _ would probably suffice; it indicates a dimension that isn't checked. IIUC nested tensors can be nested inside themselves though, so practically speaking I think you end up entirely forfeiting any notion of an 'overall array shape'.
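
As a rough illustration of the 'overall array shape' point, here is a small pure-Python sketch using nested lists as a stand-in for nested tensors. The helper `overall_shape` is hypothetical and is not part of jaxtyping or PyTorch; it just shows that once any level is ragged, there is no single shape tuple left to check against.

```python
def overall_shape(value):
    """Return the rectangular shape of a nested list, or None if it is ragged."""
    if not isinstance(value, list):
        return ()  # a scalar leaf has the empty shape
    child_shapes = {overall_shape(item) for item in value}
    if len(child_shapes) > 1 or None in child_shapes:
        return None  # children disagree, or a child is itself ragged
    inner = child_shapes.pop() if child_shapes else ()
    return (len(value),) + inner


print(overall_shape([[1.0, 2.0], [3.0, 4.0]]))         # rectangular
print(overall_shape([[1.0, 2.0, 3.0], [4.0]]))         # ragged at depth 1
print(overall_shape([[[1.0], [2.0, 3.0]], [[4.0]]]))   # ragged deeper down
```

Raggedness at any depth propagates upward in this sketch, which is why leaving one dimension unchecked only helps when the rest of the structure is still rectangular.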