patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Checking the first dimensions of a tensor #25

Open adrianjav opened 3 years ago

adrianjav commented 3 years ago

Hi!

I just found torchtyping a few days ago, and I am enjoying it so far. However, I am a bit confused about one particular use case: checking whether the arguments of a function share the same leading dimensions.

For example, if I try to write a function such as batch-wise scalar multiplication:

from torchtyping import TensorType

def batchwise_multiply(data: TensorType['B', ...], weights: TensorType['B']):
    pass

I get a NotImplementedError: Having dimensions to the left of ... is not currently supported.

Why is such a behaviour not implemented? What is the difference from performing the same operation on the right? While I haven't checked the code, to the best of my understanding, if TensorType[..., 'B'] is supported, then when you detect a situation like TensorType['B', ...] you should be able to reuse the same code but read the tensors backwards, shouldn't you?

I feel this feature would be huge for the library. At least with my programming conventions, I tend to put common dimensions in leading positions so that later I can unpack tensors using the * operator.
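For instance, here is a toy example of that convention (plain PyTorch; iterating a tensor walks over dim 0):

import torch

x = torch.randn(4, 3, 2)  # batch dimension 'B' = 4 in the leading position
first, *rest = x          # unpacks along dim 0: first has shape (3, 2), rest holds the other three slices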

patrick-kidger commented 3 years ago

Hi there; I'm glad you like the library.

So the answer is that handling the general case for ... is complicated. For example, someone might want to match against TensorType[..., "a", ...]. In this case there are multiple valid parses: "a" could correspond to any dimension present (e.g. against a tensor of shape (3, 4, 5), "a" could bind to 3, 4, or 5), and then we'd need to keep track of multiple different options when comparing against the other arguments -- which might exhibit similar complexities themselves.

As a result, torchtyping relies on being able to do a right-to-left iteration through the dimensions, matching up named dimensions as it goes along. This means torchtyping only has to keep track of a single "possibility" as it considers each argument in turn. In the general case we'd need to consider many possibilities and then start comparing which are consistent across all arguments considered.
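As a minimal sketch of that idea (not torchtyping's actual implementation; it assumes each annotation is a single leading ... followed by named dimensions):

def matches(shape, names, bindings):
    # shape: the tensor's shape; names: the named dims to the right of `...`;
    # bindings: name -> size, shared across every argument being checked.
    if len(shape) < len(names):
        return False  # not enough dimensions to cover the named part
    # Walk right-to-left, binding each name to a size exactly once.
    for name, size in zip(reversed(names), reversed(shape)):
        if bindings.setdefault(name, size) != size:
            return False  # the same name is already bound to a different size
    return True  # whatever remains on the left is absorbed by `...`

bindings = {}
matches((8, 3, 4), ('a', 'b'), bindings)     # True: binds a=3, b=4; `...` absorbs (8,)
matches((2, 2, 3, 5), ('a', 'b'), bindings)  # False: 'b' is already bound to 4, not 5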

This is the case torchtyping supports because putting batch dimensions at the start is conventional (and also typically fastest). Other special cases -- like the one you describe -- should still be possible, and would probably require their own resolution algorithms. (In this case, the same as the current algorithm, but with a left-to-right iteration.)

As a fix, you can use *tensor.unbind(dim) in place of *tensor in order to unpack along arbitrary dimensions.
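For example (plain PyTorch):

import torch

x = torch.randn(3, 4, 5)
rows = x.unbind(0)  # 3 tensors of shape (4, 5) -- the same slices plain *x would give
cols = x.unbind(1)  # 4 tensors of shape (3, 5) -- not reachable with plain * unpacking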

If you have the energy this is something I'd be happy to accept a PR on.

adrianjav commented 3 years ago

Thanks for the quick response.

I see -- I hadn't thought about the general case. I already knew about the unbind method, but I was trying to avoid it as much as possible, since the code just becomes harder to read.

For the moment, I will leave those cases unchecked, and if I find some time in the coming weeks I will try to solve this case myself and send a PR 💪🏼

shaperilio commented 2 years ago

@patrick-kidger I understand how the case of TensorType[..., "a", ...] is arguably impossible to resolve, but did you just make an arbitrary decision to support ... only on the left? I.e., couldn't you also support ... on the right, but not both sides?

patrick-kidger commented 2 years ago

So it's definitely possible to support ... on the right; it just isn't implemented. I chose the left simply because, in my experience, it's more common to have an arbitrary number of batch dimensions on the left than an arbitrary number of channel dimensions on the right.
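To summarise the asymmetry as it currently stands (this just restates the behaviour described above):

from torchtyping import TensorType

# Supported: a leading `...` absorbs arbitrary batch dimensions,
# with named dimensions matched right-to-left.
def f(x: TensorType[..., 'channels']): ...

# Not implemented: named dimensions to the left of `...` are rejected with
# "NotImplementedError: Having dimensions to the left of ... is not currently supported."
def g(x: TensorType['batch', ...]): ...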