pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Add `output_size` argument to `Upsample` forward method (just like for `ConvTranspose` Modules) #71877

Open jenkspt opened 2 years ago

jenkspt commented 2 years ago

🚀 The feature, motivation and pitch

Motivated by the same reasons ConvTranspose modules have an output_size argument.

The problem is that multiple tensor sizes will downsample to the same size. For example:

>>> import torch
>>> import torch.nn.functional as F
>>> F.interpolate(torch.rand(1, 1, 10, 10), scale_factor=.5).shape
torch.Size([1, 1, 5, 5])
>>> F.interpolate(torch.rand(1, 1, 11, 11), scale_factor=.5).shape
torch.Size([1, 1, 5, 5])
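To make the round-off loss concrete: upsampling the 5x5 result back by the same factor can only ever recover one of the two original sizes, so the 11x11 input is unrecoverable without extra information:

```python
import torch
import torch.nn.functional as F

# Both 10x10 and 11x11 downsample to 5x5, but upsampling by 2
# deterministically produces 10x10 -- the 11x11 size is lost.
x = F.interpolate(torch.rand(1, 1, 5, 5), scale_factor=2)
print(x.shape)  # torch.Size([1, 1, 10, 10])
```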

When reversing this operation with an upsampling interpolation operation, an output_size argument is required to reconstruct the original input size. The desired functionality is as follows -- to reconstruct the original size of [1, 1, 11, 11], we would do:

nn.Upsample(scale_factor=2)(torch.rand(1, 1, 5, 5), output_size=[1, 1, 11, 11])

Implementation: _ConvTransposeNd already has a utility method _output_padding for calculating the padding of the output tensor when the output_size argument is provided. While Upsample doesn't require most of its arguments (e.g. padding, dilation, kernel_size, ...), a reasonable solution would be to extract the _output_padding logic into a separate utility function that the Upsample module can use as well.
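A minimal sketch of what such an extracted utility might look like (hypothetical name and signature; the real `_ConvTransposeNd._output_padding` also takes stride/padding/kernel arguments that `Upsample` doesn't need). It computes, per spatial dimension, the range of output sizes consistent with a given scale factor:

```python
import math

def valid_output_range(input_size, scale_factor):
    """Hypothetical helper: for each spatial dim of size n, the smallest
    output consistent with scale_factor is floor(n * scale_factor), and
    any size below floor((n + 1) * scale_factor) is still consistent
    (it would round back down to n)."""
    min_sizes = [math.floor(n * scale_factor) for n in input_size]
    max_sizes = [math.floor((n + 1) * scale_factor) - 1 for n in input_size]
    return min_sizes, max_sizes
```

For a 5x5 input upsampled by 2 this gives the range ([10, 10], [11, 11]), matching the valid-size range ConvTranspose2d reports for the analogous stride-2 case.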

I'm happy to create a pull request if this is a desired feature.

Alternatives

No response

Additional context

The motivation for this feature request is to get Unet models (specifically the segmentation_models_pytorch Unet model) to work with arbitrarily sized input images. Currently these models throw errors at the upsampling op for certain input image sizes, due to the round-off problem mentioned above. However, this feature is useful for any model that includes downsampling + upsampling interpolation layers.

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

jbschlosser commented 2 years ago

Hey @jenkspt, thanks for the request! Seems like the heart of the issue is that the size arg currently supported by nn.Upsample is too restrictive when taken in via the constructor rather than during the forward pass, as the desired output size depends on the input size passed during forward.

However, I think the semantics of the proposed output_size arg may be confusing; for example, it seems to conflict with scale_factor for this example case:

nn.Upsample(scale_factor=2)(torch.rand(1, 1, 5, 5), output_size=[1, 1, 30, 30])

As a workaround, I'd recommend using F.interpolate with a size specified for upsampling rather than nn.Upsample. Does this solve the problem for you?

jenkspt commented 2 years ago

Hey @jbschlosser thanks for the feedback.

To clarify -- the issue isn't that size is too restrictive in the constructor. size and output_size would be mutually exclusive arguments that serve different purposes.

FYI: we can't use the size argument here because we want to keep scale_factor constant & handle arbitrarily sized inputs:

skip = torch.rand(1, 1, 32, 32)                                # --> torch.Size([1, 1, 32, 32])
x = F.interpolate(skip, scale_factor=1/3)                      # --> torch.Size([1, 1, 10, 10])
x = nn.Upsample(scale_factor=3).forward(x)                     # --> torch.Size([1, 1, 30, 30])
x = torch.cat([skip, x], dim=1)                                # Throws error due to size mismatch

Having the output_size argument fixes this problem:

skip = torch.rand(1, 1, 32, 32)                                         # --> torch.Size([1, 1, 32, 32])
x = F.interpolate(skip, scale_factor=1/3)                               # --> torch.Size([1, 1, 10, 10])
x = nn.Upsample(scale_factor=3).forward(x, output_size=skip.shape)      # --> torch.Size([1, 1, 32, 32])
x = torch.cat([skip, x], dim=1)                                         # --> torch.Size([1, 2, 32, 32])

To answer your questions:

"However, I think the semantics of the proposed output_size arg may be confusing; for example, it seems to conflict with scale_factor for this example case:"

nn.Upsample(scale_factor=2)(torch.rand(1, 1, 5, 5), output_size=[1, 1, 30, 30])

This example should throw an error since the only valid output sizes are 10 & 11. The semantics are the same as ConvTranspose2D:

>>> nn.ConvTranspose2d(1, 1, 2, stride=2)(torch.rand(1, 1, 5, 5), output_size=[1, 1, 30, 30])
ValueError: requested an output size of [30, 30], but valid sizes range from [10, 10] to [11, 11] (for an input of torch.Size([5, 5]))


"As a workaround, I'd recommend using F.interpolate with a size specified for upsampling rather than nn.Upsample. Does this solve the problem for you?"

As mentioned above -- using size doesn't allow for a constant scale_factor with arbitrarily sized input images/tensors.

jbschlosser commented 2 years ago

Having the output_size argument fixes this problem:

skip = torch.rand(1, 1, 32, 32)                                         # --> torch.Size([1, 1, 32, 32])
x = F.interpolate(skip, scale_factor=1/3)                               # --> torch.Size([1, 1, 10, 10])
x = nn.Upsample(scale_factor=3).forward(x, output_size=skip.shape)      # --> torch.Size([1, 1, 32, 32])
x = torch.cat([skip, x], dim=1)                                         # --> torch.Size([1, 2, 32, 32])

... As mentioned above -- using size doesn't allow for a constant scale_factor with arbitrarily sized input images/tensors

@jenkspt It's unclear to me why you couldn't just do this instead:

skip = torch.rand(1, 1, 32, 32)            # --> torch.Size(1, 1, 32, 32)
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size(1, 1, 10, 10)
x = F.interpolate(x, size=skip.shape[-2:]) # --> torch.Size(1, 1, 32, 32)
x = torch.cat([skip, x], dim=1)            # --> torch.Size(1, 2, 32, 32)
jenkspt commented 2 years ago

Here's one input size:

skip = torch.rand(1, 1, 32, 32)            # --> torch.Size(1, 1, 32, 32)
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size(1, 1, 10, 10)
x = F.interpolate(x, size=skip.shape[-2:]) # --> torch.Size(1, 1, 32, 32)
x = torch.cat([skip, x], dim=1)            # --> torch.Size(1, 2, 32, 32)

The scale factor implied by x = F.interpolate(x, size=skip.shape[-2:]) would be 32/10 = 3.2 (i.e. an effective downsample scale of 1/3.2 == 0.3125, not 1/3).

If we use a different input size:

skip = torch.rand(1, 1, 30, 30)            # --> torch.Size(1, 1, 30, 30)
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size(1, 1, 10, 10)
x = F.interpolate(x, size=skip.shape[-2:]) # --> torch.Size(1, 1, 30, 30)
x = torch.cat([skip, x], dim=1)            # --> torch.Size(1, 2, 30, 30)

The scale factor implied by x = F.interpolate(x, size=skip.shape[-2:]) would be 30/10 = 3 (an effective downsample scale of exactly 1/3).

The intention here is not to change the scaling factor with different input sizes.

jbschlosser commented 2 years ago

The intention here is not to change the scaling factor with different input sizes.

@jenkspt Are you expecting (zero?) padding around the edges to maintain the scale factor while achieving the correct input size?

jenkspt commented 2 years ago

Yes, I would expect zero padding. If F.interpolate or nn.Upsample had a padding_mode argument I would expect to use that (but they don't).
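Until such an argument exists, the zero-padding behavior can be approximated today. The sketch below (a hypothetical helper, assuming trailing zero padding is the desired semantics) interpolates with the constant scale factor and then zero-pads each spatial dim up to the skip size with F.pad:

```python
import torch
import torch.nn.functional as F

def upsample_to(x, scale_factor, size):
    """Hypothetical workaround: interpolate by a fixed scale_factor,
    then zero-pad the trailing edge of each spatial dim up to `size`."""
    x = F.interpolate(x, scale_factor=scale_factor)
    # F.pad expects pads for the last dims first: (left, right, top, bottom, ...)
    pad = []
    for target, current in zip(reversed(size), reversed(x.shape[2:])):
        pad.extend([0, target - current])
    return F.pad(x, pad)

skip = torch.rand(1, 1, 32, 32)
x = F.interpolate(skip, scale_factor=1/3)  # torch.Size([1, 1, 10, 10])
x = upsample_to(x, 3, skip.shape[-2:])     # torch.Size([1, 1, 32, 32])
out = torch.cat([skip, x], dim=1)          # torch.Size([1, 2, 32, 32])
```

This keeps the scale factor constant for any input size, at the cost of up to scale_factor - 1 zero-padded rows/columns per dim; a built-in output_size argument could hide this bookkeeping inside Upsample.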