patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Arbitrary number of dimensions - but check they are same over the argument tensors #31

Closed: ghost closed this issue 2 years ago

ghost commented 2 years ago

Consider this function:

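import torch
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # needed so @typechecked checks TensorType shapes
@typechecked
def mean_squared_error(input: TensorType["batch"], target: TensorType["batch"]):
    d = input - target
    d = d * d
    return torch.mean(d)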

As written, the annotation only accepts one-dimensional tensors, i.e. batches whose items are scalars.

I would like to ensure that the shape of each item in the input batch is the same as the shape of each item in the target batch, i.e. input.shape[1:] == target.shape[1:].

I don't want to hardcode the number of dimensions, as I would have to for a batch of images, for example: input: TensorType["batch", "c", "h", "w"].

Is this currently possible?
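For comparison, the condition input.shape[1:] == target.shape[1:] can also be enforced by hand, without fixing the number of dimensions. This is a minimal sketch, assuming only that the first two arguments expose a .shape tuple (as torch tensors do). The decorator same_item_shape and the FakeTensor stand-in are hypothetical names used for illustration, not part of torchtyping.

```python
import functools
from typing import Any, Callable


def same_item_shape(func: Callable) -> Callable:
    """Hypothetical decorator: require input.shape[1:] == target.shape[1:].

    Works for any number of dimensions; only the leading (batch)
    dimension is allowed to differ between the two arguments.
    """
    @functools.wraps(func)
    def wrapper(input: Any, target: Any, *args: Any, **kwargs: Any) -> Any:
        # Compare everything after the batch dimension.
        item_in = tuple(input.shape[1:])
        item_tg = tuple(target.shape[1:])
        if item_in != item_tg:
            raise TypeError(f"item shapes differ: {item_in} vs {item_tg}")
        return func(input, target, *args, **kwargs)

    return wrapper


# Stand-in for a tensor: any object with a .shape tuple works here.
class FakeTensor:
    def __init__(self, *shape: int) -> None:
        self.shape = shape


@same_item_shape
def describe(input: Any, target: Any) -> tuple:
    return (input.shape, target.shape)


describe(FakeTensor(8, 3, 32, 32), FakeTensor(4, 3, 32, 32))  # ok: items are (3, 32, 32)
```

The trade-off is that the check lives in the function body's wrapper rather than in the type annotation, so it is invisible in the signature; the annotation-based approach keeps the constraint next to the argument types.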

patrick-kidger commented 2 years ago

So presumably the batch dimensions should match as well, i.e. the full condition you want is input.shape == target.shape?

This can be done by naming a group of dimensions, rather than just a single dimension:

def mean_squared_error(input: TensorType["shape": ...], target: TensorType["shape": ...]):

The ... in the annotation is a literal Python ellipsis, not a placeholder.
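To make the semantics of a named group concrete, here is a plain-Python illustration (not torchtyping's implementation) of binding names to dimensions across several arguments: a single name binds one size, a named group binds a whole tuple of sizes, and every later use of the same name must match. The spec format and the functions bind and check_call are hypothetical.

```python
from typing import Dict, Sequence, Tuple, Union

# Hypothetical spec format: each entry is either a str naming one
# dimension, or a (name, ...) pair naming a group of zero or more
# dimensions, loosely mirroring TensorType["shape": ...].
Entry = Union[str, Tuple[str, object]]
Binding = Dict[str, Union[int, Tuple[int, ...]]]


def bind(spec: Sequence[Entry], shape: Tuple[int, ...], env: Binding) -> None:
    """Match one shape against a spec, extending env; raise TypeError on mismatch."""
    groups = [i for i, e in enumerate(spec) if isinstance(e, tuple)]
    if len(groups) > 1:
        raise ValueError("at most one named group per spec")
    n_single = len(spec) - len(groups)
    if (groups and len(shape) < n_single) or (not groups and len(shape) != n_single):
        raise TypeError(f"shape {shape} does not fit spec {spec}")
    pos = 0
    for entry in spec:
        if isinstance(entry, tuple):
            # The group absorbs every dimension not claimed by a single name.
            width = len(shape) - n_single
            name, value = entry[0], tuple(shape[pos:pos + width])
            pos += width
        else:
            name, value = entry, shape[pos]
            pos += 1
        # First use binds the name; later uses must agree with it.
        if env.setdefault(name, value) != value:
            raise TypeError(f"{name!r} is {value}, expected {env[name]}")


def check_call(specs: Sequence[Sequence[Entry]],
               shapes: Sequence[Tuple[int, ...]]) -> Binding:
    """Check all argument shapes against their specs with shared bindings."""
    env: Binding = {}
    for spec, shape in zip(specs, shapes):
        bind(spec, shape, env)
    return env


MSE_SPECS = [[("shape", ...)], [("shape", ...)]]
check_call(MSE_SPECS, [(8, 3, 32, 32), (8, 3, 32, 32)])  # {'shape': (8, 3, 32, 32)}
```

With specs [["batch"], ["batch"]], the same checker rejects a (4, 3) shape, matching the original complaint that TensorType["batch"] only admits scalar items. A spec such as ["batch", ("item", ...)] would separately bind the batch size and the per-item shape.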

ghost commented 2 years ago

Thanks! This is exactly what I needed.