QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

About the equivariance of wide_resnet #56

Closed hcy226 closed 2 years ago

hcy226 commented 2 years ago

Dear Author: I changed the wide_resnet by deleting the pooling and linear operations after `out = out.tensor`:

   def forward(self, x):

        # wrap the input tensor in a GeometricTensor
        x = enn.GeometricTensor(x, self.in_type)

        out = self.conv1(x)
        out = self.layer1(out)

        out = self.layer2(self.restrict1(out))

        out = self.layer3(self.restrict2(out))

        out = self.bn(out)
        out = self.relu(out)

        # extract the tensor from the GeometricTensor to use the common Pytorch operations
        out = out.tensor

        # b, c, w, h = out.shape
        # out = F.avg_pool2d(out, (w, h))

        # out = out.view(out.size(0), -1)
        # out = self.linear(out)
        return out

But I find that the equivariance no longer holds:

    print()
    print('TESTING INVARIANCE:                    ')
    print('REFLECTIONS along the VERTICAL axis:   ' + (str(torch.norm(y-y_fv))))
    print('REFLECTIONS along the HORIZONTAL axis: ' + (str(torch.norm(y-y_fh))))
    print('90 degrees ROTATIONS:                  ' + (str(torch.norm(y-y90))))
    print('REFLECTIONS along the 45 degrees axis: ' + (str(torch.norm(y-y90_fh))))

The output is:

REFLECTIONS along the VERTICAL axis:   tensor(1.9236, grad_fn=<CopyBackwards>)
REFLECTIONS along the HORIZONTAL axis: tensor(2.0343, grad_fn=<CopyBackwards>)
90 degrees ROTATIONS:                  tensor(2.0621, grad_fn=<CopyBackwards>)
REFLECTIONS along the 45 degrees axis: tensor(2.0622, grad_fn=<CopyBackwards>)

May I ask what the reason is?

Gabri95 commented 2 years ago

Hi @hcy226

The reason is that the final avg_pool2d is necessary to obtain invariance to rotations. By removing the average pooling, the output still has some spatial resolution; a rotation of the input leads to a rotation of this output feature, i.e. the model is equivariant, not invariant.

The code, however, is testing for invariance, not equivariance. This means that, when you compare y with e.g. y90, you should first rotate y90 back by 90 degrees.
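A minimal sketch of such an equivariance check (assuming the output fields are invariant under the channel action, so a 90 degrees input rotation only rotates the output spatially; the toy `model` below is a hypothetical stand-in for the pooling-free WRN, not the actual network):

```python
import torch

def check_rotation_equivariance(model, x):
    """Compare model(rot90(x)) with model(x), rotating the former back first.

    Assumes the output is a [B, C, H, W] feature map whose channels
    transform trivially, so a 90-degree input rotation should only
    rotate the output feature map spatially.
    """
    y = model(x)
    y90 = model(torch.rot90(x, k=1, dims=(-2, -1)))
    # rotate the output of the rotated input *back* before comparing
    y90_back = torch.rot90(y90, k=-1, dims=(-2, -1))
    return torch.norm(y - y90_back).item()

# Toy stand-in for an equivariant model: channel-wise scaling
model = lambda t: 2.0 * t
x = torch.randn(1, 3, 33, 33)
print(check_rotation_equivariance(model, x))  # ~0 for an equivariant model
```

For an invariance test, by contrast, the final spatial average pooling removes the rotated spatial structure, so `y` and `y90` can be compared directly.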

Hope this helps, Gabriele

hcy226 commented 2 years ago

Thanks @Gabri95! Here comes another question: why does the input shape of x in the WRN have to be 2^n + 1, e.g. 513×513 or 33×33, when I use pooling or dilation? Can I test my e2cnn-based model with an input shape of 256×256 or 512×512?

Gabri95 commented 2 years ago

Hi @hcy226

This is to ensure the pooling with stride 2 doesn't break equivariance to 90 degrees rotations. Check Figure 2 here https://arxiv.org/abs/2004.09691

Best, Gabriele
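For intuition, the size dependence can be reproduced with plain PyTorch, without e2cnn: a kernel-3, stride-2, padding-1 pool commutes with 90 degrees rotations exactly when the spatial size is odd, and fails when it is even (a sketch; the specific sizes 33 and 32 are just examples):

```python
import torch
import torch.nn.functional as F

def pool_then_rotate_gap(x):
    """|| pool(rot90(x)) - rot90(pool(x)) || for a stride-2 pool.

    With kernel 3, stride 2, padding 1, the stride-2 sampling grid
    is symmetric under 90-degree rotations only when the spatial
    size is odd, so only then does the pool commute with rotations.
    """
    pool = lambda t: F.max_pool2d(t, kernel_size=3, stride=2, padding=1)
    a = pool(torch.rot90(x, k=1, dims=(-2, -1)))
    b = torch.rot90(pool(x), k=1, dims=(-2, -1))
    return torch.norm(a - b).item()

torch.manual_seed(0)
x_odd = torch.randn(1, 1, 33, 33)
x_even = torch.randn(1, 1, 32, 32)
print(pool_then_rotate_gap(x_odd))   # 0: pooling commutes with rotation
print(pool_then_rotate_gap(x_even))  # > 0: equivariance broken
```

On the odd input, the stride-2 sampling centers {0, 2, ..., 32} map onto themselves under a 90 degrees rotation, so the two orders of operations pick the same pooling windows; on the even input they do not.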

hcy226 commented 2 years ago

> Hi @hcy226
>
> This is to ensure the pooling with stride 2 doesn't break equivariance to 90 degrees rotations. Check Figure 2 here https://arxiv.org/abs/2004.09691
>
> Best, Gabriele

Thanks! In fact, when I test the equivariance of my own e2cnn-based network with input shape [1, 3, 256, 256], I find that the average pooling and max pooling in e2cnn do not break the equivariance, but a conv with stride > 1 does, which differs from the paper. May I know the reason for that?

Gabri95 commented 2 years ago

Hi @hcy226

I am not sure what you mean: stride>1 is expected to break equivariance when the input has even size, as explained in the paper I linked to. Could you be more precise?

Best, Gabriele