patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Generalizing the Library #39

Open corwinjoy opened 1 year ago

corwinjoy commented 1 year ago

Hello, and thanks for publishing this library! I've really enjoyed reading the design and discussion documents you have posted. However, I am now trying to apply this library in a somewhat broader context: I am hoping to use it to improve the linear_operator library. The idea of that library is to abstract how tensors are stored so that matrix operations can be performed much more efficiently. I'd really like to use torchtyping to add dimension and storage-type checks to help squash bugs in this code. Unfortunately, torchtyping is written to work only with torch.Tensor objects. My first attempt was simply to hack the library and pull out a few class checks. But after more reading, I think torchtyping could be cleanly improved by using protocols (PEP 544 – Protocols: Structural subtyping (static duck typing)). The idea would be for the library to check against an abstract tensor protocol rather than torch.Tensor directly. This would make the library much more general, and I think it could help clean up the code by making explicit which tensor fields are actually used. What do you think, and do you have any suggestions on how to add this? @dannyfriar @m4rs-mt
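To make the protocol idea concrete, here is a minimal sketch using only the standard library's `typing.Protocol` and `runtime_checkable`. The names `TensorLike`, `check_shape`, and `FakeOperator` are hypothetical and chosen for illustration; this is not torchtyping's API, and it assumes that a shape check needs nothing beyond `.shape` and `.dtype`.

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class TensorLike(Protocol):
    """Hypothetical structural type: anything exposing .shape and .dtype."""

    @property
    def shape(self) -> Tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...


def check_shape(x: object, expected: Tuple[int, ...]) -> bool:
    """Return True if x satisfies TensorLike and has the expected shape.

    Only the protocol's members are touched, so any object exposing
    .shape and .dtype qualifies, whether or not it is a torch.Tensor.
    """
    # isinstance on a runtime_checkable protocol only checks that the
    # members exist; it does not check their types.
    if not isinstance(x, TensorLike):
        return False
    return tuple(x.shape) == tuple(expected)


class FakeOperator:
    """Hypothetical stand-in for a lazily represented matrix."""

    def __init__(self, n: int) -> None:
        self.shape = (n, n)
        self.dtype = "float32"


print(check_shape(FakeOperator(3), (3, 3)))  # True
print(check_shape([1, 2, 3], (3,)))  # False: a list has no .shape
```

The design trade-off is that a runtime-checkable protocol only confirms that the attributes exist, so any checks on the values themselves (shape sizes, dtype) still have to be written explicitly, as `check_shape` does for the shape.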

patrick-kidger commented 1 year ago

You might find my other project jaxtyping interesting. This is able to handle other array/tensor types -- at minimum it is tested to be able to handle JAX+numpy+pytorch+tensorflow. This is possible because it uses a slightly different syntax, so that you can specify the array/tensor type explicitly, e.g. Float[torch.Tensor, "batch channel"].
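In a string shape spec such as `"batch channel"`, each axis is given a name, and a name that appears in several annotations on the same function must refer to the same size. The sketch below illustrates only that consistency rule on plain shape tuples, using the standard library. `bind_dims` is a hypothetical helper, not jaxtyping's implementation or API, and it ignores the spec features it does not model (variadic and broadcastable dimensions, dtype checks such as `Float`).

```python
from typing import Dict, Sequence, Tuple


def bind_dims(specs: Sequence[Tuple[Tuple[int, ...], str]]) -> Dict[str, int]:
    """Check several shapes against space-separated dimension names.

    Each name is bound to a size the first time it is seen, and every later
    occurrence must match that size. Integer literals in a spec are compared
    directly. Raises ValueError on any mismatch; otherwise returns the bindings.
    """
    bindings: Dict[str, int] = {}
    for shape, spec in specs:
        names = spec.split()
        if len(names) != len(shape):
            raise ValueError(f"rank mismatch: {shape} vs {spec!r}")
        for name, size in zip(names, shape):
            if name.isdigit():
                # Fixed-size axis, e.g. "3".
                if int(name) != size:
                    raise ValueError(f"expected size {name}, got {size}")
            elif bindings.setdefault(name, size) != size:
                # Named axis already bound to a different size.
                raise ValueError(
                    f"dimension {name!r} was bound to {bindings[name]}, got {size}"
                )
    return bindings


print(bind_dims([((8, 3), "batch channel"), ((8,), "batch")]))
# {'batch': 8, 'channel': 3}
```

Passing `[((8, 3), "batch channel"), ((4,), "batch")]` instead raises `ValueError`, because `batch` cannot be both 8 and 4.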

Just as soon as I find the time, I actually intend to update TorchTyping to follow the same model -- essentially by copy-pasting and tweaking the code from jaxtyping. [Whilst setting up some backward-compatible aliases, of course.] See also the discussions here and here on what the plan is.

If you really like TorchTyping then I'd be happy to accept a PR doing this. Alternatively, you may be able to use jaxtyping directly in your project today. (Even though you aren't using JAX!)

corwinjoy commented 1 year ago

Interesting and thanks for the quick response! I will take a closer look at jaxtyping.

corwinjoy commented 1 year ago

Thanks again for the suggestion. After looking at jaxtyping, I decided that I didn't want to introduce a dependency on JAX, since that adds quite a bit of complexity. So I have gone ahead and submitted a PR that switches the code to use a protocol class rather than torch.Tensor directly. I think this helps to generalize things in a clean way...