tristandeleu / pytorch-meta

A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch
https://tristandeleu.github.io/pytorch-meta/
MIT License

[bug report] MetaConv2D output shape is not equal to torch.nn.Conv2D output shape when using CIRCULAR padding #138

Closed · mfischer-ucl closed this issue 3 years ago

mfischer-ucl commented 3 years ago

Hi everyone, thanks for the great torchmeta framework. It's really useful!

I believe I've spotted a bug in the torchmeta class MetaConv2d, specifically for the case of circular padding. The torchmeta output shapes don't match the shapes produced by the corresponding torch.nn.Conv2d. Below is a minimal working example (MWE) to reproduce the problem:

import torch
from torchmeta.modules import MetaConv2d
from torch.nn import Conv2d

device = 'cuda' if torch.cuda.is_available() else 'cpu'

regular_conv = Conv2d(6, 12, 3, 1, 1, padding_mode='circular').to(device)
meta_conv = MetaConv2d(6, 12, 3, 1, 1, padding_mode='circular').to(device)

inp = torch.rand((1, 6, 512, 512), device=device)
res1 = regular_conv(inp)
res2 = meta_conv(inp)

print(res1.shape)    # torch.Size([1, 12, 512, 512])
print(res2.shape)    # torch.Size([1, 12, 511, 511])

I've identified the source of the issue to be the handling of the circular padding case in the torchmeta conv.py module, specifically line 36:

return F.conv2d(F.pad(input, expanded_padding, mode='circular').

Replacing this line with the (blatantly copied) version from torch.nn.Conv2d:

return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),

resolves the issue and makes both outputs have the same shape (in the above MWE: [1, 12, 512, 512]).

I've tested this with a couple of networks, and so far it seems to work fine. I'm not sure, though, whether the entire if self.padding_mode == 'circular' branch in torchmeta could simply be replaced with the simpler PyTorch equivalent (line 439); see the sketch below. @tristandeleu
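
For reference, here is a minimal sketch of what that change could look like as a drop-in subclass. PatchedMetaConv2d is just an illustrative name, not torchmeta code, and it assumes a PyTorch version recent enough that Conv2d exposes _reversed_padding_repeated_twice:

from collections import OrderedDict

import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torchmeta.modules import MetaConv2d


class PatchedMetaConv2d(MetaConv2d):
    # Route every non-zero padding mode through F.pad with
    # self._reversed_padding_repeated_twice, mirroring torch.nn.Conv2d,
    # then convolve with zero padding.
    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        weight = params['weight']
        bias = params.get('bias', None)
        if self.padding_mode != 'zeros':
            padded = F.pad(input, self._reversed_padding_repeated_twice,
                           mode=self.padding_mode)
            return F.conv2d(padded, weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

With this subclass, the MWE above produces matching shapes ([1, 12, 512, 512]) for both the regular and the meta convolution.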

tristandeleu commented 3 years ago

Hi, I'm sorry for the late reply. Thank you for reporting this bug!

There have been changes to the way ConvNd modules work in later versions of PyTorch (probably 1.9?). In particular, PyTorch introduced a _conv_forward method that is very convenient for Torchmeta: we can just call this function instead of copying PyTorch's behavior (which may break, as was the case here). I will make the change and update the version of Torchmeta.
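
A rough sketch of that idea (not necessarily the exact code that shipped): subclass nn.Conv2d, substitute the meta-parameters, and let PyTorch's own _conv_forward handle every padding mode. The class name and the three-argument _conv_forward signature (input, weight, bias), available from roughly PyTorch 1.8 onward, are assumptions here:

from collections import OrderedDict

import torch.nn as nn
from torchmeta.modules import MetaModule


class MetaConv2dSketch(nn.Conv2d, MetaModule):
    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)
        # Delegate padding (zeros, circular, reflect, replicate) to PyTorch.
        return self._conv_forward(input, params['weight'], bias)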

tristandeleu commented 3 years ago

This is now fixed in Torchmeta 1.8.0.