patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Support distributions #44

Closed. PaulScemama closed this issue 1 year ago.

PaulScemama commented 1 year ago

Hi,

Thank you for this great library! It has been very helpful. I was wondering if you've considered supporting torch.distributions.Distribution types. The idea is that you could write something like func() -> Distribution["state_dim"], so that it is clear that the output is a distribution over vectors of size "state_dim". What do you think?
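
For illustration, here is a minimal sketch of the intent behind that request; the Distribution["state_dim"] subscript is hypothetical (torchtyping does not support it), and predict_state is a made-up function:

```python
import torch
from torch.distributions import Distribution, MultivariateNormal

# The proposal is that the plain `-> Distribution` annotation below could instead
# be written as `-> Distribution["state_dim"]`, documenting (and checking) that
# the output is a distribution over vectors of size state_dim.
def predict_state(state_dim: int) -> Distribution:
    loc = torch.zeros(state_dim)
    cov = torch.eye(state_dim)
    # A multivariate normal over vectors of size state_dim.
    return MultivariateNormal(loc, cov)
```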

patrick-kidger commented 1 year ago

That sounds like an interesting idea! I'm not sure how that would be checked at runtime -- do torch.distributions.Distribution objects offer a way to check the vectors they wrap?

PaulScemama commented 1 year ago

I'm not sure at the moment 😅 but I will look into it over the rest of the week and get back to you.

PaulScemama commented 1 year ago

@patrick-kidger there is both a batch_shape and an event_shape attribute on torch.distributions.Distribution.
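
For context, a minimal sketch (not from the original comment) of those two attributes, which a runtime check could inspect:

```python
import torch
from torch.distributions import MultivariateNormal

# Three independent 4-dimensional Gaussians.
dist = MultivariateNormal(torch.zeros(3, 4), torch.eye(4).expand(3, 4, 4))

print(dist.batch_shape)     # torch.Size([3]) -- independent copies of the distribution
print(dist.event_shape)     # torch.Size([4]) -- size of a single sampled vector
print(dist.sample().shape)  # torch.Size([3, 4]) == batch_shape + event_shape
```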

If you're interested, I could take a shot at putting together a minimal working example?

patrick-kidger commented 1 year ago

Hmm. Returning to this, I think this might be a can of worms -- this naturally extends to also allowing annotations like MyModule["foo", "bar"], but that now conflicts with more typical typing.Generic use cases. (And likewise, someone might equally well want to write class MyDistribution(torch.distributions.Distribution, typing.Generic[T, S]).)
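
A sketch of the clash described above (the class and type-variable names are illustrative):

```python
import typing
from torch.distributions import Distribution

T = typing.TypeVar("T")
S = typing.TypeVar("S")

class MyDistribution(Distribution, typing.Generic[T, S]):
    ...

# typing.Generic already gives subscripting this class a meaning...
ordinary_generic_use = MyDistribution[int, str]

# ...so reinterpreting MyDistribution["foo", "bar"] as a shape annotation would
# collide with it: here the strings are treated as ordinary forward references,
# not as dimension names.
would_be_ambiguous = MyDistribution["foo", "bar"]
```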

On a more practical level, torchtyping is now really only in maintenance mode, in favour of jaxtyping (which despite the name also supports PyTorch and does not depend on JAX).
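
For reference, a minimal sketch of the jaxtyping style with a PyTorch tensor (the function and dimension names are illustrative; runtime checking additionally requires pairing jaxtyping with a type checker such as beartype, per its documentation):

```python
import torch
from jaxtyping import Float

# Shape annotations in the jaxtyping style: "batch state_dim" names the two axes
# of the input, and the return annotation says the output is a vector of length "batch".
def project(x: Float[torch.Tensor, "batch state_dim"]) -> Float[torch.Tensor, "batch"]:
    return x.sum(dim=-1)
```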

All in all, I think I'm inclined to leave this feature unimplemented, I'm afraid.

PaulScemama commented 1 year ago

That's no worries at all. I appreciate the explanation. I will definitely give jaxtyping a look now as well. Thanks again!