jaanli opened 7 months ago
Thank you for your kind words!
Honestly, I'm not sure how nested tensors would work with something like jaxtyping. They don't have rectangular shapes, after all. Do you have a suggestion?
Currently the library supports a wildcard operator (e.g. `Float[Tensor, "*batch seqlen"]`) to denote a variable number of dimensions.
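For readers unfamiliar with the syntax, here is a toy, stdlib-only sketch of what a variadic `*batch` token does when it is matched against a concrete shape. `match_shape` is a hypothetical helper, not jaxtyping's implementation. It assumes a spec is a space-separated string with at most one `*name` token, which absorbs zero or more dimensions. Every other token matches exactly one dimension: an integer literal must match exactly, and a repeated name must always bind to the same size.

```python
from typing import Dict, Optional, Sequence


def match_shape(spec: str, shape: Sequence[int]) -> Optional[Dict[str, object]]:
    """Hypothetical toy matcher: return name -> size bindings, or None on mismatch.

    A token "*name" absorbs zero or more dimensions (bound to a tuple);
    a plain name binds to a single size; an integer literal must match exactly.
    """
    tokens = spec.split()
    variadic = [i for i, t in enumerate(tokens) if t.startswith("*")]
    if len(variadic) > 1:
        raise ValueError("at most one variadic '*' token is allowed")
    shape = tuple(shape)
    if variadic:
        i = variadic[0]
        n_fixed = len(tokens) - 1
        if len(shape) < n_fixed:
            return None  # not enough dimensions for the non-variadic tokens
        n_var = len(shape) - n_fixed
        # Split the shape into the parts before, inside, and after the wildcard.
        head, var, tail = shape[:i], shape[i:i + n_var], shape[i + n_var:]
        fixed_tokens = tokens[:i] + tokens[i + 1:]
        fixed_dims = head + tail
        bindings: Dict[str, object] = {tokens[i][1:]: var}
    else:
        if len(shape) != len(tokens):
            return None  # wrong number of dimensions
        fixed_tokens, fixed_dims, bindings = tokens, shape, {}
    for tok, size in zip(fixed_tokens, fixed_dims):
        if tok.isdigit():
            if int(tok) != size:
                return None
        elif bindings.setdefault(tok, size) != size:
            return None  # the same name was bound to two different sizes
    return bindings
```

For example, `"*batch seqlen"` accepts `(4, 8, 128)` by binding `batch` to `(4, 8)`, and also accepts `(128,)` with an empty batch.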
NestedTensors have an analogous concept: "ragged" dimensions, whose size can vary from one component tensor to the next.
I don't think support is necessary at this point, but if something like PyTorch's NestedTensors ends up being popular, perhaps a similar wildcard such as `~` could be used to indicate ragged dimensions?
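One way to read the `~` proposal, sketched with the standard library only: model a nested tensor as a plain list of its component shapes, assume every component has the same number of dimensions, and let `~` accept any size while every other token must agree across all components. Both `check_nested` and the `~` token are hypothetical; jaxtyping has no such syntax today.

```python
from typing import Dict, Optional, Sequence, Tuple


def check_nested(
    spec: str, component_shapes: Sequence[Tuple[int, ...]]
) -> Optional[Dict[str, int]]:
    """Hypothetical check of a NestedTensor-like list of shapes against a spec.

    "~" marks a ragged dimension whose size may differ between components;
    every other token must resolve to the same size in every component.
    """
    tokens = spec.split()
    bindings: Dict[str, int] = {}
    for shape in component_shapes:
        if len(shape) != len(tokens):
            return None  # all components must have the same number of dims
        for tok, size in zip(tokens, shape):
            if tok == "~":
                continue  # ragged: any size is accepted and never bound
            if tok.isdigit():
                if int(tok) != size:
                    return None
            elif bindings.setdefault(tok, size) != size:
                return None  # a non-ragged name differs between components
    return bindings
```

With two sequences of lengths 2 and 3 sharing an embedding size of 5, `"~ embed"` passes and binds `embed` to 5, whereas `"seqlen embed"` fails because `seqlen` would need two different sizes.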
Probably something like the existing `_` may suffice, indicating a dimension that isn't checked.
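A toy sketch of why `_` might be enough: named dimensions are tied together across every argument of a call, while `_` matches any single dimension and is never compared. So if each component of a nested tensor were checked like a separate argument annotated `"_ embed"`, the ragged lengths would be ignored but `embed` would still be enforced. `check_args` is a hypothetical stdlib-only helper that mimics these semantics; it is not jaxtyping's API.

```python
from typing import Dict, Optional, Sequence, Tuple


def check_args(
    specs: Sequence[str], shapes: Sequence[Tuple[int, ...]]
) -> Optional[Dict[str, int]]:
    """Hypothetical toy checker: match each shape against its spec.

    Named dimensions share one binding across all arguments (so "embed" must
    be the same size everywhere), integer literals must match exactly, and
    "_" matches any single dimension without being recorded or compared.
    """
    if len(specs) != len(shapes):
        raise ValueError("need exactly one spec per shape")
    bindings: Dict[str, int] = {}
    for spec, shape in zip(specs, shapes):
        tokens = spec.split()
        if len(tokens) != len(shape):
            return None  # wrong number of dimensions
        for tok, size in zip(tokens, shape):
            if tok == "_":
                continue  # unchecked dimension
            if tok.isdigit():
                if int(tok) != size:
                    return None
            elif bindings.setdefault(tok, size) != size:
                return None  # same name, different sizes
    return bindings
```

For shapes `(2, 5)` and `(3, 5)`, the specs `"_ embed"` pass, while `"seqlen embed"` fails because `seqlen` cannot be both 2 and 3.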
IIUC nested tensors can be nested inside themselves though, so practically speaking I think you end up entirely forfeiting any notion of an 'overall array shape'.
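As a rough illustration of this point, using plain nested Python lists as a stand-in for nested tensors: the hypothetical `summary_shape` below marks a dimension as `None` when its size differs between siblings, and gives up entirely (returns `None`) when siblings are nested to different depths, since then no single shape tuple describes the data.

```python
from typing import Any, Optional, Tuple

Shape = Tuple[Optional[int], ...]


def summary_shape(x: Any) -> Optional[Shape]:
    """Hypothetical: summarise a nested list as a shape tuple.

    Leaves (non-lists) have shape (). A list's shape is its length followed by
    the merged shapes of its children: a dimension whose size differs between
    children becomes None (ragged). If children have different depths, no
    shape tuple describes them at all, so None is returned.
    """
    if not isinstance(x, list):
        return ()
    child_shapes = [summary_shape(c) for c in x]
    if any(s is None for s in child_shapes):
        return None  # some child already has no summary shape
    if not child_shapes:
        return (0,)
    if len({len(s) for s in child_shapes}) != 1:
        return None  # children of different depth: no overall shape exists
    merged = tuple(
        dims[0] if len(set(dims)) == 1 else None
        for dims in zip(*child_shapes)
    )
    return (len(x),) + merged
```

For example, `[[1, 2], [3]]` summarises to `(2, None)`, while `[[1, 2], [[3]]]` has no summary shape at all.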
Currently, NestedTensors give errors with jaxtyping (example NestedTensors can be found at https://pytorch.org/tutorials/prototype/nestedtensor.html and https://pytorch.org/docs/stable/nested.html).
Just an idea; it could be helpful to more folks! :) This library has already saved me from many bugs, btw 🙏