patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Support Any for shape #5

Closed: AdilZouitine closed this issue 3 years ago

AdilZouitine commented 3 years ago

Hi, I'd like to thank you for this cool library. I was frustrated that I couldn't find shape typing for PyTorch, and I had planned to write it myself if it didn't exist.

I think your API is great; however, I find that specifying a dimension of arbitrary size as -1 is not very intuitive (I saw that you have several other ways to declare it). One idea is to declare a dimension of arbitrary size using typing.Any, as the nptyping library does:

```python
from typing import Any
import numpy as np
from nptyping import NDArray

NDArray[(3, 3, Any), np.float32]
```

Here the array is typed with no constraint on the size of its last dimension. Applying the same idea to your library would give:

```python
from typing import Any
from torchtyping import TensorType
import torch

TensorType[3, 3, Any, torch.float32]
# instead of
TensorType[3, 3, -1, torch.float32]
```
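To make the proposed semantics concrete, here is a minimal pure-Python sketch in which typing.Any and -1 are treated as the same wildcard when checking a shape. The helpers dim_matches and shape_matches are hypothetical and not part of torchtyping; they compare plain tuples of ints, so the sketch runs without PyTorch.

```python
from typing import Any, Sequence, Tuple

# Hypothetical helpers, not part of torchtyping: each entry of a spec is either
# an exact int size or a wildcard, written as typing.Any or as -1.
def dim_matches(spec: Any, size: int) -> bool:
    """Return True if a dimension of length `size` satisfies `spec`."""
    if spec is Any or spec == -1:
        return True  # wildcard: any size is accepted
    return spec == size


def shape_matches(spec: Sequence[Any], shape: Tuple[int, ...]) -> bool:
    """Check a concrete shape against a spec of the same rank."""
    if len(spec) != len(shape):
        return False  # different number of dimensions
    return all(dim_matches(s, n) for s, n in zip(spec, shape))


print(shape_matches((3, 3, Any), (3, 3, 7)))  # True
print(shape_matches((3, 3, -1), (3, 3, 7)))   # True: same meaning as Any
print(shape_matches((3, 3, Any), (3, 4, 7)))  # False: second dim must be 3
```

Treating Any as an alias for -1 in a sketch like this keeps the existing spelling working while adding the more readable one.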

What do you think of this? If you're interested, I can try to make a pull request!

Thanks again for developing this wonderful library.

patrick-kidger commented 3 years ago

Yep, that sounds good to me. I'd be very happy to accept a PR on this.

patrick-kidger commented 3 years ago

In passing: the syntax currently used is TensorType[3, 3, -1, torch.float32], not TensorType[(3, 3, -1), torch.float32] (as written in your example). If you particularly want the second syntax, I'd be happy to support that as well.
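Supporting both syntaxes mostly comes down to flattening a tuple of sizes before anything else inspects the arguments. A minimal sketch, assuming a hypothetical FlexSpec class (not torchtyping's implementation) and using float as a stand-in for a torch dtype so it runs without PyTorch:

```python
from typing import Any, List, Tuple

# Hypothetical sketch, not torchtyping's implementation: accept both the flat
# form FlexSpec[3, 3, -1, float] and the tuple form FlexSpec[(3, 3, -1), float].
class FlexSpec:
    def __class_getitem__(cls, item: Any) -> Tuple[Any, ...]:
        # FlexSpec[x] passes a bare object; FlexSpec[x, y] passes a tuple.
        if not isinstance(item, tuple):
            item = (item,)
        flat: List[Any] = []
        for arg in item:
            if isinstance(arg, tuple):
                flat.extend(arg)  # tuple form: splice the sizes in place
            else:
                flat.append(arg)
        return tuple(flat)


# Both spellings normalize to the same flat argument list.
print(FlexSpec[3, 3, -1, float])    # (3, 3, -1, <class 'float'>)
print(FlexSpec[(3, 3, -1), float])  # (3, 3, -1, <class 'float'>)
```

Python hands __class_getitem__ a single object for FlexSpec[x] and a tuple for FlexSpec[x, y], which is why the bare case is wrapped first.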

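The different arguments in the [] are distinguished by type rather than position. This was a deliberate choice to:

  • (a) make it easy to write just TensorType[float] rather than something like TensorType[(...,), float].
  • (b) make it possible to use slice syntax like TensorType["channels": 4] to assign both a name and a size to a dimension.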

AdilZouitine commented 3 years ago

> In passing: the syntax currently used is TensorType[3, 3, -1, torch.float32], not TensorType[(3, 3, -1), torch.float32] (as written in your example).
>
> The different arguments in the [] are distinguished by type rather than position. This was a deliberate choice to:
>
>   • (a) make it easy to write just TensorType[float] rather than something like TensorType[(...,), float].
>   • (b) make it possible to use slice syntax like TensorType["channels": 4] to assign both a name and a size to a dimension.

I understand your reasoning! (I'll update my example accordingly.) 😄
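The type-based dispatch described in (a) and (b) can be sketched as follows. This is a hypothetical Spec class, not torchtyping's code; plain Python classes such as float stand in for torch dtypes. Each argument is classified by its type: an int (including -1), typing.Any or ... is a size, a str is a named dimension, a slice such as "channels": 4 carries both a name and a size, and any other class is taken as the dtype.

```python
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical sketch, not torchtyping's implementation: arguments inside []
# are classified by their type rather than by their position.
class Spec:
    def __class_getitem__(cls, item: Any) -> Dict[str, Any]:
        # Spec[float] passes a bare object; Spec[3, float] passes a tuple.
        if not isinstance(item, tuple):
            item = (item,)
        dims: List[Tuple[Optional[str], Any]] = []  # (name, size) pairs
        dtype: Optional[type] = None
        for arg in item:
            # Check Any first: on Python 3.11+ typing.Any is itself a class,
            # so it would otherwise be mistaken for a dtype below.
            if arg is Any or arg is Ellipsis or (
                isinstance(arg, int) and not isinstance(arg, bool)
            ):
                dims.append((None, arg))  # unnamed size: int, -1, Any or ...
            elif isinstance(arg, slice):
                dims.append((arg.start, arg.stop))  # "channels": 4 -> name, size
            elif isinstance(arg, str):
                dims.append((arg, Any))  # named dimension of any size
            elif isinstance(arg, type):
                dtype = arg  # a class such as float stands in for the dtype
            else:
                raise TypeError(f"unsupported argument: {arg!r}")
        return {"dims": dims, "dtype": dtype}


print(Spec[float])          # {'dims': [], 'dtype': <class 'float'>}
print(Spec["channels": 4])  # {'dims': [('channels', 4)], 'dtype': None}
print(Spec[3, Any, "batch", float])
```

Because nothing depends on position, Spec[float] needs no placeholder for the shape, and the slice form lets a single argument carry both a name and a size.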

patrick-kidger commented 3 years ago

Closed by #6.