QUVA-Lab / e2cnn

E(2)-Equivariant CNNs Library for Pytorch
https://quva-lab.github.io/e2cnn/

`SO(2)`-equivariant network for ImageNet #37

Closed kristian-georgiev closed 3 years ago

kristian-georgiev commented 3 years ago

Hi Gabriele, thanks again for the great library!

I apologize in advance if the answer to this can be found in the docs/paper/thesis, I couldn't find it there.

My question is rather soft: what is a "reasonable" SO(2)-equivariant architecture? E.g. what would be the "proper" way to generalize the Wide ResNets in examples/e2wrn.py to be (approximately) equivariant w.r.t. SO(2) rather than C_N?

My attempt so far is:

This gist contains my attempt to make an SO(2)-invariant ResNet-18. An equivariance test, however, shows that I am rather off. In particular, the output of

import torch

# the gist's SO(2)-invariant ResNet-18 is assumed to be built and assigned to `model`
model.cuda()
model.eval()

img_size = 221
repeats = 500
rot_diffs = []
rand_diffs = []

for _ in range(repeats):
    # a random input, its 90-degree rotation, and an unrelated random input
    x = torch.randn([1, 3, img_size, img_size]).cuda()
    xrot = torch.rot90(x, k=1, dims=[2, 3])
    xrand = torch.randn([1, 3, img_size, img_size]).cuda()

    with torch.no_grad():
        z, lat = model(x, with_latent=True)
        zrot, latrot = model(xrot, with_latent=True)
        zrand, latrand = model(xrand, with_latent=True)

    # an invariant model should map x and xrot to (numerically) identical logits
    rot_diffs.append(torch.norm(z - zrot))
    rand_diffs.append(torch.norm(z - zrand))

print(f'l2 diff between logits of image and its 90-degree rotation: {torch.mean(torch.stack(rot_diffs)).item():.3f}')
print(f'l2 diff between logits of two arbitrary images: {torch.mean(torch.stack(rand_diffs)).item():.3f}')

is

l2 diff between logits of image and its 90-degree rotation: 0.617
l2 diff between logits of two arbitrary images: 0.739

What is going wrong here?

In case this is helpful, this is the model "topology":

MODEL TOPOLOGY:
    0 - 
    1 - conv1
    2 - conv1._basisexpansion
    3 - conv1._basisexpansion.block_expansion_('irrep_0', 'irrep_1')
    4 - conv1._basisexpansion.block_expansion_('irrep_0', 'irrep_2')
    5 - conv1._basisexpansion.block_expansion_('irrep_0', 'irrep_3')
    6 - relu1
    7 - bn1
    8 - maxpool
    9 - layer1
    10 - layer1.0
    11 - layer1.0.conv1
    12 - layer1.0.conv1._basisexpansion
    13 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_1', 'irrep_1')
    14 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_1', 'irrep_2')
    15 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_1', 'irrep_3')
    16 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_2', 'irrep_1')
    17 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_2', 'irrep_2')
    18 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_2', 'irrep_3')
    19 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_3', 'irrep_1')
    20 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_3', 'irrep_2')
    21 - layer1.0.conv1._basisexpansion.block_expansion_('irrep_3', 'irrep_3')
    22 - layer1.0.bn1
    23 - layer1.0.conv2
    24 - layer1.0.conv2._basisexpansion
    25 - layer1.0.bn2
    26 - layer1.1
    27 - layer1.1.conv1
    28 - layer1.1.conv1._basisexpansion
    29 - layer1.1.bn1
    30 - layer1.1.conv2
    31 - layer1.1.conv2._basisexpansion
    32 - layer1.1.bn2
    33 - layer2
    34 - layer2.0
    35 - layer2.0.conv1
    36 - layer2.0.conv1._basisexpansion
    37 - layer2.0.bn1
    38 - layer2.0.conv2
    39 - layer2.0.conv2._basisexpansion
    40 - layer2.0.bn2
    41 - layer2.0.shortcut
    42 - layer2.0.shortcut.0
    43 - layer2.0.shortcut.0._basisexpansion
    44 - layer2.0.shortcut.0._basisexpansion.block_expansion_('irrep_1', 'irrep_1')
    45 - layer2.0.shortcut.0._basisexpansion.block_expansion_('irrep_2', 'irrep_2')
    46 - layer2.0.shortcut.0._basisexpansion.block_expansion_('irrep_3', 'irrep_3')
        ...
        (layers 2.1 to 4.0 only change multiplicities in field types, no 'block_expansion_'s here) 
        ...
    92 - layer4.1.conv2._basisexpansion.block_expansion_('irrep_1', 'irrep_0')
    93 - layer4.1.conv2._basisexpansion.block_expansion_('irrep_2', 'irrep_0')
    94 - layer4.1.conv2._basisexpansion.block_expansion_('irrep_3', 'irrep_0')
    95 - layer4.1.relu2
    96 - layer4.1.bn2
    97 - layer4.1.bn2.batch_norm_[1]
    98 - linear
    99 - avgpool

Thanks in advance!

Gabri95 commented 3 years ago

Hi @kristian-georgiev,

I tried running your network, and this problem seems related to your other issue: https://github.com/QUVA-Lab/e2cnn/issues/36

If you feed inputs of shape 191x191, the equivariance error of your model becomes 0. If you want to work with higher-resolution images, I would recommend adapting the strides and/or the number of convolution layers in the model. Otherwise, you could also downsample your images to 191x191 before feeding them into the model.
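
For example, a minimal sketch of the downsampling option, reusing the `model` and `x` from your test snippet and assuming bilinear interpolation is acceptable for your data (the 191x191 size comes from the check above):

import torch.nn.functional as F

# resize an arbitrary-resolution batch to the size at which the model is exactly equivariant
x_small = F.interpolate(x, size=(191, 191), mode='bilinear', align_corners=False)
z, lat = model(x_small, with_latent=True)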

Hope this helps!

Best, Gabriele

kristian-georgiev commented 3 years ago

Thank you for the quick response and apologies for the delay. I agree with what you said: the lack of equivariance indeed seems to stem from the pooling and dilation. In addition, it seems I've made other poor design choices, since I was not able to train the networks to a reasonable accuracy on ImageNet.

I've since taken a closer look at your experiments (Table 5.1 and e2cnn_experiments/experiments/models/exp_e2sfcnn.py in particular) and have made some corrections to my architecture (gist is updated). However, the current architecture seems to be unstable (the loss consistently becomes NaN around the middle of the first epoch on ImageNet, for both max_frequency=-3 and -5); any pointers on what may be causing this are greatly appreciated.

Potentially related, I also have a couple of high-level questions:

Thanks again for your time and apologies for the long issue!

Gabri95 commented 3 years ago

Hi @kristian-georgiev

> However, the current architecture seems to be unstable (the loss consistently becomes NaN around the middle of the first epoch on ImageNet, for both max_frequency=-3 and -5); any pointers on what may be causing this are greatly appreciated.

To answer this, I will need to look a bit more in detail at your architecture and probably try to run it myself. Unfortunately, I don't have time for it at the moment so, if you don't mind, I'll come back to you about this later.

I will answer the other (very good, btw) questions in the meantime:

> I don't understand what is happening in these lines in the gated_normpool layers.

There, we use gated non-linearities only on the non-trivial channels (see line 1247). For the trivial channels, we still use ELU. Because the gate of a gated non-linearity is a trivial field, we also need to add an additional trivial field for each non-trivial irrep in the input.
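
As an illustration of that pattern (a minimal sketch, not the exact field types used in exp_e2sfcnn.py), using e2cnn's GatedNonLinearity1 with explicit 'gate'/'gated' labels:

from e2cnn import gspaces
from e2cnn import nn as enn

gspace = gspaces.Rot2dOnR2(N=-1, maximum_frequency=3)  # SO(2) acting on the plane

# the non-trivial irreps we want to gate, plus one trivial field (a gate) per irrep
gated = [gspace.irrep(1), gspace.irrep(2), gspace.irrep(3)]
gates = [gspace.trivial_repr] * len(gated)

in_type = enn.FieldType(gspace, gates + gated)
labels = ['gate'] * len(gates) + ['gated'] * len(gated)
nonlinearity = enn.GatedNonLinearity1(in_type, gates=labels)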

> What are S and M? It seems that they go unused(?)

You're right, they are there just for documentation. In the comment at line 1290, I use them to describe the total number of channels. I is the number of fields which require a gate and, therefore, is also the number of gates (which are trivial fields) to be added. S is the size of the features, which is equal to the size of all irreps which require a gate, plus 1 for the size of the trivial field using ELU. M is the total size, features + #gates, i.e. S + I.

> Why do we not have the same setup in the hnet_normpool layers (here)?

That is a slightly different architecture. We still apply ELU to each trivial irrep and an independent non-linearity to each non-trivial irrep. Here, however, the independent non-linearity is a norm-ReLU, not a gated one. While the norm-ReLU is computed directly on the input irrep, the gated non-linearity requires an additional gate. The code for the network using the gated non-linearity is a bit more complex, since I need to account for the additional parameters introduced in the model by adding these additional outputs (the gates) in each convolution block.
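
For comparison, a minimal sketch of that alternative, assuming e2cnn's NormNonLinearity with its default norm-ReLU function (no extra gate fields are needed here):

from e2cnn import gspaces
from e2cnn import nn as enn

gspace = gspaces.Rot2dOnR2(N=-1, maximum_frequency=3)

# norm-ReLU acts directly on each non-trivial irrep field,
# so the field type only contains the irreps themselves
feat_type = enn.FieldType(gspace, [gspace.irrep(1), gspace.irrep(2), gspace.irrep(3)])
norm_relu = enn.NormNonLinearity(feat_type, function='n_relu', bias=True)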

> Where does the expression `t /= 16 * s * 2 3 / 4` (from here) come from ... ?

This is just a simple heuristic I created manually to ensure the model has roughly the same number of parameters as the C_16 regular GCNN. It should work for different frequencies of the HNET, but only if you compare to the specific C_16 architecture I used in those experiments. I would not trust that formula in another setting.

> Is there any benefit/harm in using max_frequency higher than the highest irrep used? E.g. have max_frequency=10 but use irreps of frequency at most 3?

The answer is "it depends" 😅

> Rather open-ended: do you expect the trends from rows 29-44 of Table 5.1 to hold true for higher-resolution, harder datasets than MNIST (e.g. ImageNet)?

I think this result is indeed strongly related to the low resolution of the data used; on very high-resolution data, this might change. Note, however, that this also refers to the resolution of the field of view of a neuron. So, even if you have very high-resolution inputs, the neurons in the first layers will only process small patches. Neurons in the deepest layers of the network can probably benefit from higher frequencies, but I doubt the first layers can (unless you use very wide filters).

> A silly question, but just to double-check: the order of representations in the specification of the field type doesn't matter, correct?

It depends. A permutation of the representations inside a FieldType will result in a fully equivalent architecture. However, if you keep all representations of the same type next to each other in the FieldType, you will get much better inference time. This is because the code can then access the parts of the input tensor associated with the same representation using slicing rather than indexing: while slicing only requires a view internally, advanced indexing requires a full copy of the indexed part (see this and this). I just realised this was not really explicit in the documentation; that's my fault. I will update the documentation with some more comments on this.
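
Concretely, a minimal sketch of the two orderings (the field types below are made up just for illustration):

from e2cnn import gspaces
from e2cnn import nn as enn

gspace = gspaces.Rot2dOnR2(N=-1, maximum_frequency=3)

# grouped: all copies of the same irrep are adjacent -> handled internally with slicing (views)
grouped = enn.FieldType(gspace, [gspace.irrep(0)] * 8 + [gspace.irrep(1)] * 8 + [gspace.irrep(2)] * 8)

# interleaved: an equivalent model, but forces advanced indexing (full copies) at runtime
interleaved = enn.FieldType(gspace, [gspace.irrep(0), gspace.irrep(1), gspace.irrep(2)] * 8)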

Let me know if you have any other doubts.

Best, Gabriele

Gabri95 commented 3 years ago

Hi @kristian-georgiev ,

I had a quick look at the new code in the gist. I don't see any relevant mistake to be honest.

I don't know if you are still using the same equivariance check you had in the first gist. Back then, you were using 191x191 inputs. I guess there has been some change in the architecture, so now you need to use 193x193 inputs to get zero equivariance error (maybe you changed the first conv layer?).

Do you still have the same issue with the NaN loss? I did not try to train this model, to be honest, but I can have a better look if you have some specific problem.

One note: I see you tried to make your code very general, supporting both SO(2) (with different frequencies) and C_N (for different values of N). I'd recommend implementing two different models, since the kinds of operations you use for SO(2) and C_N are generally different. Implementing the two models separately allows for much simpler code, which is also more readable.
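
For instance, a minimal sketch of the two starting points (the layer choices below are just typical defaults, not your architecture):

from e2cnn import gspaces
from e2cnn import nn as enn

# C_N model: regular representations support pointwise non-linearities like ReLU
gspace_cn = gspaces.Rot2dOnR2(N=8)
feat_cn = enn.FieldType(gspace_cn, [gspace_cn.regular_repr] * 16)
relu_cn = enn.ReLU(feat_cn)

# SO(2) model: irrep fields need norm or gated non-linearities instead
gspace_so2 = gspaces.Rot2dOnR2(N=-1, maximum_frequency=3)
feat_so2 = enn.FieldType(gspace_so2, [gspace_so2.irrep(1), gspace_so2.irrep(2)] * 8)
norm_relu_so2 = enn.NormNonLinearity(feat_so2)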

Let me know if you have more questions

Best, Gabriele

ahyunSeo commented 1 year ago

Hello, @kristian-georgiev

I'm sorry to leave a comment on a year-old issue. Did you fix the issue with the NaN loss? (I also got a NaN loss using your gist code)

Best, Ahyun

JoaoGuibs commented 12 months ago

Hi @kristian-georgiev, I stumbled upon this discussion and was wondering whether you had managed to train the equivariant ResNet model up to a reasonable accuracy on ImageNet? Thank you in advance.

kristian-georgiev commented 12 months ago

Hi @JoaoGuibs and @ahyunSeo. I have not updated my code snippet since my last comment in this thread.