QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

Cannot pass weights of R2Conv as a positional argument #51

Closed. Branikas closed this issue 2 years ago.

Branikas commented 2 years ago

Hi there. Many thanks for this amazing library and the whole equivariant framework, it's a very inspiring work!

I think the title of the issue pretty much sums it up: I am trying to pass two positional arguments to an R2Conv module, the input data and the weight parameters.

The problem is that I get an error saying that 3 positional arguments were given when 2 were expected. The standard PyTorch conv2d module allows this, so I was wondering whether something is wrong in my implementation or whether this behavior is simply not supported by R2Conv. Many thanks in advance,

Gabri95 commented 2 years ago

Hi @Branikas

As far as I know, PyTorch's torch.nn.Conv2d also accepts only the input tensor in its forward() method (see the PyTorch documentation for torch.nn.Conv2d).
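For illustration, here is a small sketch using only plain PyTorch (no e2cnn) that contrasts the two call styles: the torch.nn.Conv2d module keeps its weights internally and rejects an extra positional argument, while torch.nn.functional.conv2d expects the weights to be passed in. The layer sizes and input shape are arbitrary choices for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, padding=1)
x = torch.randn(1, 1, 8, 8)

# Module call: forward() takes only the input; the weights live in conv.weight.
y_module = conv(x)

# Passing the weights as an extra positional argument fails, because
# forward() receives one more positional argument than it declares.
try:
    conv(x, conv.weight)
    raised = False
except TypeError:
    raised = True

# Functional call: the weights (and optional bias) are passed explicitly.
y_functional = F.conv2d(x, conv.weight, conv.bias, padding=1)
```

The same "3 positional arguments were given" kind of TypeError is what a module whose forward() only declares the input produces when called with an additional tensor.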

Is it possible that you mean torch.nn.functional.conv2d? That, however, is not a torch.nn.Module but a plain function. We do not have an equivalent function in our library, because R2Conv first needs to expand its weights to construct the convolutional filter. This expansion requires some additional information, which is generated in R2Conv.__init__ and stored in R2Conv.basisexpansion.
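To make the point about the expansion concrete, below is a toy analogue in plain PyTorch. It is not e2cnn's implementation: the helpers `c4_invariant_basis`, `expanded_conv2d` and the class `ToyExpandedConv` are hypothetical names, and the basis is just three hand-made 3x3 masks that are invariant under 90-degree rotations, rather than the steerable basis e2cnn actually builds. It shows why a functional version would need the basis as an extra input: the learnable weights are only expansion coefficients, and the filter passed to `F.conv2d` is obtained by combining them with a fixed basis that the module builds once in `__init__`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def c4_invariant_basis(dtype=torch.float32):
    # Hypothetical toy basis: three 3x3 masks unchanged by 90-degree rotation.
    center = torch.zeros(3, 3, dtype=dtype)
    center[1, 1] = 1.0
    edges = torch.zeros(3, 3, dtype=dtype)
    edges[0, 1] = edges[1, 0] = edges[1, 2] = edges[2, 1] = 1.0
    corners = torch.zeros(3, 3, dtype=dtype)
    corners[0, 0] = corners[0, 2] = corners[2, 0] = corners[2, 2] = 1.0
    return torch.stack([center, edges, corners])  # (num_basis, 3, 3)


def expanded_conv2d(x, weights, basis):
    # weights: (out_channels, in_channels, num_basis) expansion coefficients
    # basis:   (num_basis, k, k) fixed filters
    # Expand the coefficients into a full (out, in, k, k) convolution filter.
    filt = torch.einsum("oib,bhw->oihw", weights, basis)
    return F.conv2d(x, filt, padding=basis.shape[-1] // 2)


class ToyExpandedConv(nn.Module):
    # Hypothetical module: the basis is built once in __init__ and stored as a
    # buffer; only the coefficients are learnable, so forward() needs just x.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.register_buffer("basis", c4_invariant_basis())
        self.weights = nn.Parameter(
            torch.randn(out_channels, in_channels, self.basis.shape[0])
        )

    def forward(self, x):
        return expanded_conv2d(x, self.weights, self.basis)


torch.manual_seed(0)
x = torch.randn(1, 1, 8, 8)
layer = ToyExpandedConv(in_channels=1, out_channels=2)
y = layer(x)

# A functional call must receive both the coefficients AND the basis.
y_fn = expanded_conv2d(x, layer.weights, layer.basis)

# Because every expanded filter is rotation invariant, rotating the input
# by 90 degrees rotates the output by 90 degrees.
y_rot = layer(torch.rot90(x, 1, dims=(2, 3)))
```

In this sketch the module hides the basis, which is why its forward() only takes the input; any functional counterpart has to be handed the basis (or whatever object performs the expansion) in addition to the weights.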

Hope this helps

Gabriele

Branikas commented 2 years ago

Hi Gabriele,

Many thanks for the quick response. Yes, that makes sense. Thanks a lot for this. Feel free to close the issue now.

Regards, Stathis