Closed by ghost 2 years ago.
So presumably the batch dimensions should match as well, i.e. the full condition you want to check is input.shape == target.shape?
This can be done by naming a group of dimensions, rather than just a single dimension:
def mean_squared_error(input: TensorType["shape": ...], target: TensorType["shape": ...]):
    return ((input - target) ** 2).mean()
Using a literal `...`.
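As a rough illustration of why naming the group enforces equality, here is a minimal, dependency-free model of the binding rule. It assumes a checker binds a named group to the whole shape tuple the first time it sees the name and compares against that binding afterwards. The helper names are hypothetical, shapes are plain tuples rather than tensors, and this is not torchtyping's actual implementation.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def bind_named_group(name: str, shape: Shape, bindings: Dict[str, Shape]) -> bool:
    """Bind `name` to a whole shape, or check it against an earlier binding.

    Hypothetical, simplified model of what a checker does for an annotation
    like TensorType["shape": ...]; not torchtyping's real code.
    """
    if name in bindings:
        # The group was already bound by an earlier argument: shapes must agree.
        return bindings[name] == shape
    # First occurrence: remember the shape so later arguments are compared to it.
    bindings[name] = shape
    return True


def check_same_group(input_shape: Shape, target_shape: Shape) -> bool:
    """Mimic mean_squared_error(input: ["shape": ...], target: ["shape": ...])."""
    bindings: Dict[str, Shape] = {}
    return bind_named_group("shape", input_shape, bindings) and bind_named_group(
        "shape", target_shape, bindings
    )


print(check_same_group((32, 3, 64, 64), (32, 3, 64, 64)))  # True
print(check_same_group((32, 3, 64, 64), (32, 1, 64, 64)))  # False
print(check_same_group((32,), (32,)))                      # True
```

Because both arguments use the same name, any number of dimensions is accepted, but the two shapes must be identical.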
Thanks! This is exactly what I needed.
Consider this function
The above only allows batches to contain 1-element values (i.e. scalars).
I would like to ensure that the shape of the items in the input batch is the same as the shape of the items in the target batch, i.e.
input.shape[1:] == target.shape[1:]
I don't want to hardcode the number of dimensions, as I would have to for, e.g., a batch of images:
input: TensorType["batch", "c", "h", "w"]
Is this currently possible?
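A dependency-free sketch of the check asked for above, using hypothetical helper names (not torchtyping's implementation): one leading dimension is bound to a name, and a named group covers all remaining dimensions. In torchtyping terms this would presumably be something like TensorType["batch", "item": ...] on both arguments, assuming the library accepts a named `...` after a named dimension (worth confirming in its README). Giving the target's leading dimension a different name relaxes the check to input.shape[1:] == target.shape[1:] only.

```python
from typing import Dict, Tuple, Union

Shape = Tuple[int, ...]
Binding = Union[int, Shape]


def check_dim_then_group(
    shape: Shape, dim_name: str, group_name: str, bindings: Dict[str, Binding]
) -> bool:
    """Check `shape` against a hypothetical spec [dim_name, group_name: ...].

    The first dimension is bound to `dim_name`; all remaining dimensions
    (possibly none) are bound, as one tuple, to `group_name`. A name seen
    before must bind to the same value again.
    """
    if len(shape) < 1:
        return False  # no leading (batch) dimension to bind
    for name, value in ((dim_name, shape[0]), (group_name, shape[1:])):
        if name in bindings and bindings[name] != value:
            return False
        bindings.setdefault(name, value)
    return True


def items_match(input_shape: Shape, target_shape: Shape, same_batch: bool = True) -> bool:
    """True if both shapes share the trailing "item" group and, when
    `same_batch` is set, the batch size too."""
    bindings: Dict[str, Binding] = {}
    target_batch = "batch" if same_batch else "target_batch"
    return check_dim_then_group(
        input_shape, "batch", "item", bindings
    ) and check_dim_then_group(target_shape, target_batch, "item", bindings)


print(items_match((32, 3, 64, 64), (32, 3, 64, 64)))                    # True
print(items_match((32, 3, 64, 64), (16, 3, 64, 64)))                    # False: batch differs
print(items_match((32, 3, 64, 64), (16, 3, 64, 64), same_batch=False))  # True
print(items_match((32, 3, 64, 64), (32, 3, 32, 32)))                    # False: item shape differs
print(items_match((32,), (32,)))                                        # True: scalar items
```

The number of item dimensions is never fixed: images, sequences, or scalars all pass as long as the item shapes agree.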