atomicarchitects / equiformer_v2

[ICLR 2024] EquiformerV2: Improved Equivariant Transformer for Scaling to Higher-Degree Representations
https://arxiv.org/abs/2306.12059
MIT License

Small equivariant example #5

Open DeNeutoy opened 1 year ago

DeNeutoy commented 1 year ago

Hi @yilunliao,

Thanks for the nice codebase. I am adapting it for another purpose, and I ran into some issues when checking that the outputs are actually equivariant. Are there any init flags that must be set in a certain way to guarantee equivariance?

I have a snippet equivalent to this:

import torch_geometric
import torch
from e3nn import o3
from torch_geometric.data import Data
from nets.equiformer_v2.equiformer_v2_oc20 import EquiformerV2_OC20

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
pos = torch.randn(10, 3)
data = Data(pos=pos, edge_index=edge_index)

R = o3.rand_matrix()  # already a tensor; wrapping it in torch.tensor() only triggers a copy warning

model = EquiformerV2_OC20(
        num_layers=2,
        attn_hidden_channels=16,
        ffn_hidden_channels=16,
        sphere_channels=16,
        edge_channels=16,
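        alpha_drop=0.0,  # turn off attention dropout for the equivariance test
        drop_path_rate=0.0,  # turn off drop path for the equivariance test
    )
model.eval()  # make sure no stochastic layers are active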

energy1, forces1 = model(data)
rotated_pos = torch.matmul(pos, R)
data.pos = rotated_pos
energy2, forces2 = model(data)

assert torch.allclose(energy1, energy2, atol=1.0e-5)  # exact == on floats is too strict
assert torch.allclose(forces1, torch.matmul(forces2, R), atol=1.0e-3)

The energies are equal, but the forces are not equivariant under rotation. I've turned off all dropout and set the model to eval mode. Are there any other tricks needed to retain genuinely equivariant behaviour? Thanks!

yilunliao commented 1 year ago

Hi @DeNeutoy

Can you let me know how large the difference is? After scanning the code, can you also make sure you are comparing this:

 assert torch.allclose(torch.matmul(forces1, R), forces2, atol=1.0e-3)

not this:

 assert torch.allclose(forces1, torch.matmul(forces2, R), atol=1.0e-3)

I think this is because forces2 is computed from rotated input positions, so you have to either rotate forces2 back or rotate forces1 before comparing the two force outputs.

Best
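The comparison direction can be made concrete with a minimal sketch. It swaps in a hypothetical toy force field (harmonic springs along each edge, not EquiformerV2) whose forces are exactly equivariant. Positions are stored as row vectors and rotated as `pos @ R`. Under that convention, the forces from the rotated input should equal `forces1 @ R`, and equivalently `forces2 @ R.T` should recover `forces1`:

```python
import torch

torch.manual_seed(0)

def toy_spring_forces(pos, edge_index, k=1.0, r0=1.0):
    """Hypothetical stand-in for a force model: harmonic springs along each edge.
    pos: [num_nodes, 3] row vectors. Returns forces of shape [num_nodes, 3]."""
    src, dst = edge_index
    vec = pos[dst] - pos[src]                     # edge vectors, rotate with pos
    dist = vec.norm(dim=-1, keepdim=True)         # rotation-invariant lengths
    f = -k * (dist - r0) * vec / dist             # force on dst from this edge
    forces = torch.zeros_like(pos)
    forces.index_add_(0, dst, f)                  # accumulate on receiver nodes
    return forces

def random_rotation(dtype=torch.float64):
    # Orthogonal Q from a QR decomposition; flip one column so that det(Q) = +1
    q, r = torch.linalg.qr(torch.randn(3, 3, dtype=dtype))
    q = q * torch.sign(torch.diagonal(r))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
pos = torch.randn(10, 3, dtype=torch.float64)
R = random_rotation()

forces1 = toy_spring_forces(pos, edge_index)
forces2 = toy_spring_forces(pos @ R, edge_index)  # rotated input positions

print(torch.allclose(forces1 @ R, forces2))       # rotate forces1 forward: True
print(torch.allclose(forces1, forces2 @ R.T))     # rotate forces2 back: True
print(torch.allclose(forces1, forces2 @ R))       # reversed check: generally False
```

In this toy, nodes 3 to 9 have no edges, so their forces are zero and pass any of the three checks.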

DeNeutoy commented 1 year ago

Thanks for the response @yilunliao. Sorry, I was rotating forces1 and comparing as you suggested; the reversed comparison was a bug in my snippet, not in my actual code. I dug into this a little, and when I look at the nodes which are not correct, I see this:

print((forces1 == forces2).all(-1).all(-1))
tensor([False, False, False,  True,  True,  True,  True,  True,  True,  True])

This indicates that only nodes which have edges are affected. I then confirmed this by modifying the input graph: changing the receivers changes which nodes are affected, e.g.:

senders = torch.tensor([0, 1, 2, 1, 2, 0])
receivers = torch.tensor([1, 0, 1, 2, 0, 4])
# Rerun, get:
tensor([False, False, False,  True, False,  True,  True,  True,  True,  True])  # node 4 now fails

but changing the senders doesn't:

senders = torch.tensor([0, 1, 2, 1, 2, 4])
receivers = torch.tensor([1, 0, 1, 2, 0, 0])
# Rerun, get:
tensor([False, False, False,  True, True,  True,  True,  True,  True,  True])

Does this point to anything? I then started stepping through the code: the embeddings match up until the edge-degree embedding, but if I remove that embedding, they still stop matching after the TransBlockV2 stack.

For the node indices which don't match, the absolute difference is large:

print(torch.abs(forces1 - forces2)[:3, :, :].mean())
1.0051
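A per-node check with a tolerance can separate the nodes that fail from the nodes that pass. Exact `==` on floating-point outputs will usually fail after a rotation because of rounding, even for an equivariant model.

This sketch makes three assumptions:
- `forces1` and `forces2` are `[num_nodes, 3]` tensors from the original and rotated inputs.
- Positions were rotated as `pos @ R`.
- The helper name `per_node_equivariance_error` is hypothetical, and the demo data is synthetic.

```python
import torch

def per_node_equivariance_error(forces1, forces2, R):
    """Hypothetical helper: max absolute error per node between forces2 (from the
    rotated input) and forces1 rotated the same way (forces1 @ R).
    Assumes positions were rotated as pos @ R (row-vector convention)."""
    expected = forces1 @ R
    return (forces2 - expected).abs().amax(dim=-1)   # shape [num_nodes]

def report(forces1, forces2, R, atol=1e-4):
    err = per_node_equivariance_error(forces1, forces2, R)
    ok = err <= atol                                   # per-node pass/fail mask
    print("equivariant per node:", ok.tolist())
    if (~ok).any():
        print("mean abs error on failing nodes:", err[~ok].mean().item())
    return ok

# Synthetic demo: node 1 is deliberately broken
torch.manual_seed(0)
R, _ = torch.linalg.qr(torch.randn(3, 3))  # any orthogonal matrix works for this demo
forces1 = torch.randn(5, 3)
forces2 = forces1 @ R
forces2[1] += 0.5
ok = report(forces1, forces2, R)
```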

This is with a completely untrained model, although I wouldn't expect that to make a difference.

Any help is much appreciated!

yilunliao commented 1 year ago

@DeNeutoy

I see.

Can you check that the output of the edge-degree embedding satisfies the equivariance constraint? Since this embedding only uses rotations and applies linear layers to the m = 0 components, it is strictly equivariant and can easily be tested.

I have limited bandwidth until the weekend or next week, but I will look at this and provide another example. (Ping me if you have not heard from me.)

yilunliao commented 1 year ago

@DeNeutoy

Sorry for the late reply.

Have you figured out the issue? If not, can you please update me on what you think the problem is?

I have an upcoming deadline, so I may be slow to respond, but I will make sure we find the reason.

Best

DeNeutoy commented 1 year ago

Hi @yilunliao ,

I haven't, unfortunately. I tried looking into the edge-degree embeddings, but it's not as simple as checking a rotation of the input vectors. The edge embedding outputs an SO3_embedding object, which internally has a _rotate method that depends on the SO3_rotations and Wigner matrices computed in the model's forward pass. So it was unclear to me how to "unrotate" the embeddings.

If you had a small example, that would be helpful - but I understand if this is difficult to produce. These things are quite complex!

yilunliao commented 1 year ago

Hi @DeNeutoy .

Here is how we rotate the embedding back to the original coordinate frame after the SO(2) linear layers: https://github.com/atomicarchitects/equiformer_v2/blob/main/nets/equiformer_v2/so3.py#L452

Sure. I can provide a simple example to test that, but I will do that next weekend due to an upcoming deadline.

Best

BurgerAndreas commented 6 months ago

Hi @yilunliao,

A small example would be extremely useful for building on your codebase! Could you help us out?