Closed PaulScemama closed 1 year ago
That sounds like an interesting idea! I'm not sure how that would be checked at runtime -- do `torch.Distribution` objects offer a way to check the vectors they wrap?
I'm not sure at the moment 😅 but I will look into it over the rest of the week and get back to you!
@patrick-kidger there are both:

- `event_shape`: returns the shape of a single sample (without batching).
- `batch_shape`: returns the shape over which parameters are batched.

If you're interested, I could give a shot at starting a minimally sufficient example?
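For illustration, here is a small sketch of how those two attributes behave on a couple of standard distributions (assuming a recent `torch.distributions` API):

```python
import torch
from torch.distributions import Normal, Independent, MultivariateNormal

# A batch of 2 diagonal Gaussians over 3-dimensional vectors: wrapping in
# Independent reinterprets the last dimension as part of the event.
base = Normal(torch.zeros(2, 3), torch.ones(2, 3))
d = Independent(base, 1)
print(d.batch_shape)  # torch.Size([2])
print(d.event_shape)  # torch.Size([3])

# A single 3-dimensional multivariate normal: no batching at all.
mvn = MultivariateNormal(torch.zeros(3), torch.eye(3))
print(mvn.batch_shape)  # torch.Size([])
print(mvn.event_shape)  # torch.Size([3])
```

So a `Distribution["state_dim"]`-style check would presumably compare the annotation against `event_shape`.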
Hmm. Returning to this, I think this might be a can of worms -- it naturally extends to also allowing annotations like `MyModule["foo", "bar"]`, but that now conflicts with more-typical `typing.Generic` use cases. (And likewise, someone might equally well want to do `class MyDistribution(torch.Distribution, typing.Generic[T, S])`.)
On a more practical level, torchtyping is now really only in maintenance mode, in favour of jaxtyping (which, despite the name, also supports PyTorch and does not depend on JAX).
All-in-all I think I'm inclined to leave this feature unimplemented, I'm afraid.
That's no worries at all. I appreciate the explanation. I will definitely give jaxtyping a look now as well. Thanks again!
Hi,
Thank you for this great library! It has been very helpful. I was wondering if you've considered supporting `torch.Distribution` datatypes. The idea being that you could do something like `func() -> Distribution["state_dim"]`, so that it is clear that the output is a distribution over vectors of size `"state_dim"`. What do you think?
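A minimal sketch of how such an annotation might be checked at runtime, assuming the check is done against `event_shape`. Every name here is hypothetical -- `FakeDistribution` stands in for `torch.distributions.Distribution`, and `Distribution`/`check_event_shape` are not real torchtyping API:

```python
from dataclasses import dataclass


@dataclass
class FakeDistribution:
    # Stand-in for torch.distributions.Distribution, which exposes
    # an `event_shape` attribute describing a single sample's shape.
    event_shape: tuple


class Distribution:
    """Subscriptable annotation: Distribution["state_dim"] records dim names."""

    def __class_getitem__(cls, dims):
        if not isinstance(dims, tuple):
            dims = (dims,)
        # Return a lightweight annotation type carrying the named dims.
        return type("DistributionAnnotation", (), {"dims": dims})


def check_event_shape(dist, annotation, bindings):
    """Resolve each named dim via `bindings` and compare to event_shape."""
    expected = tuple(bindings[name] for name in annotation.dims)
    return tuple(dist.event_shape) == expected


ann = Distribution["state_dim"]
dist = FakeDistribution(event_shape=(4,))
print(check_event_shape(dist, ann, {"state_dim": 4}))  # True
```

In a real implementation the `bindings` dictionary would be built up across a function's arguments, the same way torchtyping unifies named tensor dimensions.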