jenkspt opened this issue 2 years ago
Hey @jenkspt, thanks for the request! Seems like the heart of the issue is that the `size` arg currently supported by `nn.Upsample` is too restrictive when taken in via the constructor rather than during the forward pass, as the desired output size depends on the input size passed during forward.

However, I think the semantics of the proposed `output_size` arg may be confusing; for example, it seems to conflict with `scale_factor` for this example case:

```python
nn.Upsample(scale_factor=2)(torch.rand(1, 1, 5, 5), output_size=[1, 1, 30, 30])
```

As a workaround, I'd recommend using `F.interpolate` with a `size` specified for upsampling rather than `nn.Upsample`. Does this solve the problem for you?
Hey @jbschlosser, thanks for the feedback.

To clarify -- the issue isn't that `size` is too restrictive in the constructor. `size` and `output_size` would be mutually exclusive arguments that serve different purposes:

- `size` sets the size of the tensor returned from `F.interpolate`, and is also used to implicitly calculate the `scale_factor`. `size` must be an integer, so the size of the output will not change with different input sizes -- but the `scale_factor` will change.
- `scale_factor` sets the interpolation scaling, which is used to calculate the size of the output. However, the computed size will often be non-integer -- in that case `F.interpolate` uses floor division, i.e.

```python
>>> F.interpolate(torch.rand(1, 1, 10, 10), scale_factor=.99).shape
torch.Size([1, 1, 9, 9])
```
- `output_size` would only be used when `scale_factor` is specified, and would pad the output of `F.interpolate` when the sizes don't match. This doesn't change the `scale_factor`.

`output_size` is only necessary when upsampling -- the problem it solves is best illustrated with a Unet-like skip connection layer.

_FYI: we can't use the `size` argument here because we want to keep `scale_factor` constant & handle arbitrarily sized inputs._
```python
skip = torch.rand(1, 1, 32, 32)            # --> torch.Size([1, 1, 32, 32])
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size([1, 1, 10, 10])
x = nn.Upsample(scale_factor=3).forward(x) # --> torch.Size([1, 1, 30, 30])
x = torch.cat([skip, x], dim=1)            # Throws error due to size mismatch
```
Having the `output_size` argument fixes this problem:

```python
skip = torch.rand(1, 1, 32, 32)            # --> torch.Size([1, 1, 32, 32])
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size([1, 1, 10, 10])
x = nn.Upsample(scale_factor=3).forward(x, output_size=skip.shape)  # --> torch.Size([1, 1, 32, 32])
x = torch.cat([skip, x], dim=1)            # --> torch.Size([1, 2, 32, 32])
```
To answer your questions:

"However, I think the semantics of the proposed `output_size` arg may be confusing; for example, it seems to conflict with `scale_factor` for this example case: `nn.Upsample(scale_factor=2)(torch.rand(1, 1, 5, 5), output_size=[1, 1, 30, 30])`"

This example should throw an error, since the only valid output sizes are 10 & 11. The semantics are the same as `ConvTranspose2d`:

```python
>>> nn.ConvTranspose2d(1, 1, 2, stride=2)(torch.rand(1, 1, 5, 5), output_size=[1, 1, 30, 30])
ValueError: requested an output size of [30, 30], but valid sizes range from [10, 10] to [11, 11] (for an input of torch.Size([5, 5]))
```
"As a workaround, I'd recommend using `F.interpolate` with a `size` specified for upsampling rather than `nn.Upsample`. Does this solve the problem for you?"

As mentioned above -- using `size` doesn't allow for a constant `scale_factor` with arbitrarily sized input images/tensors.
@jenkspt It's unclear to me why you couldn't just do this instead:
```python
skip = torch.rand(1, 1, 32, 32)            # --> torch.Size([1, 1, 32, 32])
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size([1, 1, 10, 10])
x = F.interpolate(x, size=skip.shape[-2:]) # --> torch.Size([1, 1, 32, 32])
x = torch.cat([skip, x], dim=1)            # --> torch.Size([1, 2, 32, 32])
```
Here's one input size:

```python
skip = torch.rand(1, 1, 32, 32)            # --> torch.Size([1, 1, 32, 32])
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size([1, 1, 10, 10])
x = F.interpolate(x, size=skip.shape[-2:]) # --> torch.Size([1, 1, 32, 32])
x = torch.cat([skip, x], dim=1)            # --> torch.Size([1, 2, 32, 32])
```

The scale factor calculated for `x = F.interpolate(x, size=skip.shape[-2:])` would be 1 / 3.2 == 0.3125.
If we use a different input size:

```python
skip = torch.rand(1, 1, 30, 30)            # --> torch.Size([1, 1, 30, 30])
x = F.interpolate(skip, scale_factor=1/3)  # --> torch.Size([1, 1, 10, 10])
x = F.interpolate(x, size=skip.shape[-2:]) # --> torch.Size([1, 1, 30, 30])
x = torch.cat([skip, x], dim=1)            # --> torch.Size([1, 2, 30, 30])
```

The scale factor calculated for `x = F.interpolate(x, size=skip.shape[-2:])` would be 1 / 3 == 0.3333.
The intention here is not to change the scaling factor with different input sizes.
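The drift can be verified with plain size arithmetic (a sketch using the same floor behavior as `F.interpolate`; `implied_upsample_scale` is an illustrative helper, not a PyTorch function):

```python
import math

def implied_upsample_scale(in_size, down_scale):
    """Scale the size-based upsample implicitly uses to restore in_size
    after a floor-divided downsample -- it varies with in_size."""
    down = math.floor(in_size * down_scale)  # size the downsampling interpolate produces
    return in_size / down                    # implied (non-constant) upsample scale

implied_upsample_scale(32, 1/3)  # -> 3.2 (i.e. 1/3.2 == 0.3125, not the intended 1/3)
implied_upsample_scale(30, 1/3)  # -> 3.0
```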
@jenkspt Are you expecting (zero?) padding around the edges to maintain the scale factor while achieving the correct input size?
Yes, I would expect zero padding. If `F.interpolate` or `nn.Upsample` had a `padding_mode` argument, I would expect to use that (but they don't).
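Until such an argument exists, the zero padding can be applied manually with `F.pad`; a sketch (the `upsample_to` helper is hypothetical, not a PyTorch API):

```python
import torch
import torch.nn.functional as F

def upsample_to(x, scale_factor, size):
    """Upsample with a fixed scale_factor, then zero-pad the result
    (right/bottom) up to the target spatial size."""
    x = F.interpolate(x, scale_factor=scale_factor)
    pad_h = size[-2] - x.shape[-2]
    pad_w = size[-1] - x.shape[-1]
    return F.pad(x, (0, pad_w, 0, pad_h))  # (left, right, top, bottom) zero padding

skip = torch.rand(1, 1, 32, 32)
x = F.interpolate(skip, scale_factor=1/3)  # torch.Size([1, 1, 10, 10])
x = upsample_to(x, 3, skip.shape)          # torch.Size([1, 1, 32, 32])
out = torch.cat([skip, x], dim=1)          # torch.Size([1, 2, 32, 32])
```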
🚀 The feature, motivation and pitch

Motivated by the same reasons the ConvTranspose modules have an `output_size` argument. The problem is that multiple tensor sizes will downsample to the same size -- for example, with `scale_factor=1/2`, input sizes 10 and 11 both floor-downsample to 5.
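The ambiguity is plain size arithmetic (a minimal sketch; the `scale_factor=1/2` choice here is just illustrative):

```python
import math

scale = 1 / 2
# two different input sizes floor-downsample to the same output size...
assert math.floor(10 * scale) == 5
assert math.floor(11 * scale) == 5
# ...so upsampling by 1/scale alone can only ever restore one of them
assert math.floor(5 * (1 / scale)) == 10  # size 11 is unreachable without padding
```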
When reversing this operation with an upsampling interpolation operation, an `output_size` argument is required to reconstruct the original input size. The desired functionality: to reconstruct the original size of [1, 1, 11, 11], the upsampling op would accept `output_size=[1, 1, 11, 11]` and pad its result accordingly.

Implementation:
`_ConvTransposeNd` already has a utility method `_output_padding` for calculating the padding of the output tensor when the `output_size` argument is provided. While `Upsample` doesn't require most of that method's arguments (e.g. padding, dilation, kernel_size, ...), it seems like a reasonable solution would be to extract the `_output_padding` method into a separate utility function and use it with the `Upsample` module as well. I'm happy to create a pull request if this is a desired feature.
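The extracted utility could look roughly like this (a sketch under assumptions: the name `interp_output_padding` and the valid-range rule mirroring ConvTranspose's stride-based range are mine, not existing PyTorch code):

```python
import math

def interp_output_padding(input_size, output_size, scale_factor):
    """Hypothetical standalone helper modeled on _ConvTransposeNd._output_padding:
    returns the per-dimension zero padding needed to reach output_size."""
    pads = []
    for in_d, out_d in zip(input_size, output_size):
        base = math.floor(in_d * scale_factor)      # size F.interpolate would produce
        max_d = base + math.ceil(scale_factor) - 1  # assumed range, as with ConvTranspose's stride
        if not base <= out_d <= max_d:
            raise ValueError(
                f"requested an output size of {list(output_size)}, but valid sizes "
                f"range from {base} to {max_d} (for an input of {list(input_size)})")
        pads.append(out_d - base)                   # zero padding for this dimension
    return pads

interp_output_padding((5, 5), (11, 11), 2)  # -> [1, 1], matching ConvTranspose2d's rule
```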
Alternatives
No response
Additional context
The motivation for this feature request is to get Unet models (specifically the segmentation_models_pytorch Unet model) to work with arbitrarily sized input images. Currently these models throw errors at the upsampling op for certain input image sizes, due to the round-off problem mentioned above. However, this feature is useful for any model that includes downsampling + upsampling interpolation layers.
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345