QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

checking equivariance for the angles that are not 90n #61

Closed ahyunSeo closed 1 year ago

ahyunSeo commented 1 year ago

Hello,

I want to check equivariance for angles like 22.5, 45, ... that are not multiples of 90 degrees. To avoid interpolation issues, I only compared the values at the image center. However, I could not find a proper way to verify the equivariance. I created a gist for sharing.

In the code, I use the C8 group.

```python
import torch
from e2cnn import gspaces
from e2cnn import nn as enn

# setup (reconstructed from the description above)
N = 8
r2_act = gspaces.Rot2dOnR2(N=N)
in_type = enn.FieldType(r2_act, 3 * [r2_act.trivial_repr])  # 3-channel trivial repr.
out_type = enn.FieldType(r2_act, [r2_act.regular_repr])
model = enn.R2Conv(in_type, out_type, kernel_size=3).cuda()  # e.g. a single 3x3 conv

for img_size in [181, 183, 185, 187]:
    x = torch.randn([1, 3, img_size, img_size]).cuda()
    x = enn.GeometricTensor(x, in_type)
    xrot1 = x.transform(N - 1)  # 45 deg CW
    xrot2 = x.transform(N - 2)  # 90 deg CW

    with torch.no_grad():
        lat = model(x)
        latrot1 = model(xrot1)
        latrot2 = model(xrot2)
        w = lat.shape[-1]
        center = (w - 1) // 2
        latrot1 = latrot1.transform(1)  # 45 deg CCW
        latrot2 = latrot2.transform(2)  # 90 deg CCW

        # compare the feature vector at the center pixel
        print(torch.allclose(lat.tensor[0, 0, center, center],
                             latrot1.tensor[0, 0, center, center]))
        print(torch.allclose(lat.tensor[0, 0, center, center],
                             latrot2.tensor[0, 0, center, center]))
```

The printed outputs are always False, True. I even tried a single 3x3 conv as the 'model', but allclose never passes for the 45-degree rotation.

I wonder whether I'm doing something wrong.

Best,

Ahyun

Gabri95 commented 1 year ago

Hi @ahyunSeo

Unfortunately, it is impossible to achieve perfect equivariance to rotations by angles that are not multiples of 90 degrees, since the images are sampled on a square grid.
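As a quick sanity check of this point, one can count how many integer grid points a rotation maps exactly onto other grid points (a small illustrative script, not part of e2cnn):

```python
import math

def exact_lattice_hits(angle_deg, n=3, tol=1e-9):
    """Count points of the (2n+1)x(2n+1) integer grid that a rotation
    by angle_deg maps exactly onto integer grid points."""
    c = math.cos(math.radians(angle_deg))
    s = math.sin(math.radians(angle_deg))
    hits = 0
    for i in range(-n, n + 1):
        for j in range(-n, n + 1):
            x, y = c * i - s * j, s * i + c * j
            if abs(x - round(x)) < tol and abs(y - round(y)) < tol:
                hits += 1
    return hits

print(exact_lattice_hits(90))  # 49: every point lands on a grid point
print(exact_lattice_hits(45))  # 1: only the center survives
```

For 90 degrees the pixel grid maps onto itself, so equivariance can be exact; for 45 degrees almost every rotated pixel falls between grid points and must be interpolated.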

To test equivariance to these rotations, you can use a few tricks:

You can take a look at the MNIST example here. As you can see, I first upsample the MNIST digits, then rotate them and, finally, downsample them before feeding them into the model. You can also see in block 9 that the outputs can still vary by a few percentage points.
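The upsample-rotate-downsample preprocessing can be sketched in plain PyTorch (the helper names and the upsampling factor are illustrative, not taken from the notebook):

```python
import math
import torch
import torch.nn.functional as F

def rotate(x, angle_deg):
    """Rotate a batch of images (N, C, H, W) by an arbitrary angle
    about the image center, using bilinear resampling."""
    a = math.radians(angle_deg)
    mat = torch.tensor([[math.cos(a), -math.sin(a), 0.0],
                        [math.sin(a),  math.cos(a), 0.0]])
    mat = mat.unsqueeze(0).expand(x.shape[0], -1, -1)
    grid = F.affine_grid(mat, list(x.shape), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

def rotate_smooth(x, angle_deg, factor=3):
    """Upsample, rotate, then downsample back to the original size
    to reduce interpolation artifacts for angles that are not
    multiples of 90 degrees."""
    h, w = x.shape[-2:]
    x = F.interpolate(x, scale_factor=factor, mode="bilinear",
                      align_corners=False)
    x = rotate(x, angle_deg)
    return F.interpolate(x, size=(h, w), mode="bilinear",
                         align_corners=False)
```

Rotating at the higher resolution makes the bilinear interpolation error smaller relative to the signal, so the equivariance test becomes much tighter, though never exact.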

If you want to improve the stability of your model, you can also use wider filters (e.g. 5x5 instead of 3x3) to reduce the discretization artifacts.

Hope this helps! Gabriele

ahyunSeo commented 1 year ago

Hi, Gabriele

Thank you for the useful tricks and comments. I will try some of them.

Best, Ahyun