FabianFuchsML / se3-transformer-public

code for the SE3 Transformers paper: https://arxiv.org/abs/2006.10503
486 stars 67 forks source link

rotation equivariance for type 1 feature #3

Closed Chen-Cai-OSU closed 3 years ago

Chen-Cai-OSU commented 3 years ago

Hello,

I am playing around with the following script to check whether the output rotates accordingly when I rotate the input. I am only using d and f (all ones). When I rotate d by a rotation matrix, the output does not seem to rotate accordingly.

I spent some time checking the feature types, but they seem correct to me. The script is modified from your README example. Could you take a quick look at why this is the case? Thank you very much!

import dgl
import torch
from copy import deepcopy

from equivariant_attention.from_se3cnn.SO3 import rot

torch.manual_seed(0)
import torch.nn as nn
from torch.utils.data import DataLoader

from equivariant_attention.fibers import Fiber
from equivariant_attention.modules import GSE3Res, GNormSE3, GConvSE3, GMaxPooling, get_basis_and_r
# The maximum feature type is harmonic degree 3
from experiments.qm9.QM9 import QM9Dataset
from dgl.nn.pytorch.glob import AvgPooling

class GAvgVecPooling(nn.Module):
    """Graph average pooling for degree-1 (vector) node features.

    Each Cartesian component of ``features['1']`` is pooled separately with
    dgl's ``AvgPooling`` (a per-graph mean over nodes). The pooled components
    are then concatenated.
    """

    def __init__(self):
        super().__init__()
        self.pool = AvgPooling()

    def forward(self, features, G, **kwargs):
        """Pool the type-1 features of ``G``.

        Args:
            features: dict of per-degree node features. Only ``'1'`` is read.
                Indexing ``[..., i]`` for i in 0..2 assumes the last axis holds
                the 3 vector components (presumably (num_nodes, channels, 3);
                confirm against the producing layer).
            G: the (batched) graph handed to ``self.pool``.
            **kwargs: ignored. Accepted so this layer can be called like the
                other SE(3) layers (``r=``, ``basis=``).

        Returns:
            The pooled x, y and z components concatenated along dim 1. The
            layout is component-major: all channels of x, then y, then z.
        """
        vectors = features['1']
        # Print the size directly. summary() prints and returns None, so it
        # cannot be embedded in an f-string.
        print(f'before pool: Size: {vectors.size()}')
        pooled = [self.pool(G, vectors[..., i]) for i in range(3)]
        return torch.cat(pooled, dim=1)

def build_model():
    """Build the two-layer SE(3) model used by this equivariance check.

    Layers, applied in order:
      1) ``GSE3Res``: a multihead attention block mapping ``num_features``
         type-0 channels to 16 channels each of types 0 and 1.
      2) ``GConvSE3``: a TFN layer (no attention) with self-interaction,
         mapping to a single type-1 channel.

    Returns:
        nn.ModuleList holding the two layers.
    """
    # Fiber(num_degrees, num_channels): num_degrees is a count of 0-based
    # feature types, so Fiber(2, 16) carries types 0 and 1 with 16 channels
    # each, and Fiber(1, n) carries only type 0. An explicit ``structure``
    # lists the fiber's entries; here a single channel of type 1.
    num_degrees = 2
    num_features = 16
    fiber_in = Fiber(1, num_features)
    fiber_mid = Fiber(num_degrees, 16)
    fiber_out = Fiber(2, 1, structure=[(1, 1)])

    return nn.ModuleList([
        GSE3Res(fiber_in, fiber_mid),
        GConvSE3(fiber_mid, fiber_out, self_interaction=True),
    ])

def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into one batched graph and a label tensor."""
    graphs, labels = (list(column) for column in zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels)

def summary(features):
    """Print the size of a tensor, or of every tensor in a degree-keyed dict.

    Output goes to stdout only; the function returns None.
    """
    if not isinstance(features, dict):
        print(f'Size: {features.size()}')
        return
    for degree, tensor in features.items():
        print(f'type: {degree}; Size: {tensor.size()}')

def set_feat(G, R, num_features=16):
    """Rotate the edge vectors of ``G`` and reset its features for the test.

    Mutates ``G`` in place and prints it:
      * ``edata['d']`` is right-multiplied by ``R``;
      * ``edata['w']`` and ``ndata['x']`` are replaced by zero-width tensors
        (the test intends to feed only ``d`` and ``f`` to the model);
      * ``ndata['f']`` becomes all ones of shape (num_nodes, num_features, 1).

    Returns:
        ``(G, features)``, where ``features = {'0': G.ndata['f']}`` is the
        type-0 input dict consumed by the SE(3) layers (activations travel as
        a dict keyed by feature type, as an integer string).
    """
    num_edges = G.edata['w'].size(0)
    num_nodes = G.ndata['x'].size(0)
    G.edata['d'] = G.edata['d'] @ R
    G.edata['w'] = torch.rand((num_edges, 0))
    G.ndata['x'] = torch.rand((num_nodes, 0))
    G.ndata['f'] = torch.ones((G.ndata['f'].size(0), num_features, 1))
    print(G)
    return G, {'0': G.ndata['f']}

def apply_model(model, G, features, num_degrees=2):
    """Run every layer of ``model`` on ``G`` and return the first type-1 channel.

    ``get_basis_and_r(G, num_degrees - 1)`` is evaluated once, and its outputs
    are passed to every layer as ``basis=`` and ``r=``. Feature sizes are
    printed before and after each layer.

    Returns:
        ``features['1'][:, 0, :]`` taken from the final layer's output dict.
    """
    edge_basis, edge_r = get_basis_and_r(G, num_degrees - 1)
    current = features
    for layer in model:
        print(f'feat before {layer}')
        summary(current)
        current = layer(current, G=G, r=edge_r, basis=edge_basis)
        print(f'feat after {layer}')
        summary(current)
        print('-' * 100)
    type1 = current['1']
    return type1[:, 0, :]

if __name__ == '__main__':
    # Equivariance check: rotating the input edge vectors by R2 should rotate
    # the type-1 output by the same R2. Both rotations are applied by
    # right-multiplication, in set_feat and below.
    dataset = QM9Dataset('./QM9_data.pt', "homo", mode='train', fully_connected=True)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate)
    G, y = next(iter(dataloader))
    G1 = deepcopy(G)
    G2 = deepcopy(G)
    model = build_model()

    R1 = rot(0, 0, 0)
    G1, features = set_feat(G1, R1)
    out1 = apply_model(model, G1, features)
    summary(out1)

    R2 = rot(10, 0, 30)
    G2, features = set_feat(G2, R2)
    out2 = apply_model(model, G2, features)
    summary(out2)

    # Largest absolute deviation. A signed torch.max would ignore errors where
    # out2 falls below out1 @ R2. For an equivariant model this should be at
    # float32 round-off level.
    print(torch.max(torch.abs(out2 - out1 @ R2)))
Chen-Cai-OSU commented 3 years ago

I picked one point x from the complete graph, tried different rotations (around the z axis), and plotted two trajectories: 1) applying the rotation to x first and then applying a simple SE(3) model (consisting of only a GConvSE3 layer; yellow points), and 2) applying the SE(3) model to x first and then rotating the output (blue points).

The yellow circle overlaps with the blue one when viewed from the top of the z axis, but they are not exactly the same, which is what confuses me. The equivariance error should be very small, as mentioned in the paper, so I don't know why the two trajectories still differ.

Screen Shot 2020-10-28 at 11 23 11 AM
FabianFuchsML commented 3 years ago

Thank you for pointing this out! I fixed a bug in the repository and ran your test script again. I now get an error of 2.6822e-07 (the same order of magnitude as what we reported in Table 1). Feel free to check out the latest commit and see whether you still get anything unexpected.

Chen-Cai-OSU commented 3 years ago

Thank you for your timely response! I confirm that on my computer the error is also roughly 1e-7. Let me leave this issue open for a few days in case I encounter some other issues.

Chen-Cai-OSU commented 3 years ago

Hi,

I found that although the equivariance error is very small for most graphs, it can sometimes be large. I print the graph index and the error whenever the error is larger than 1e-5. Here are some of the numbers:

0 0.09846335649490356
9 0.04315918684005737
38 0.012460410594940186
40 0.00337403267621994
60 0.07560521364212036
65 0.0023030638694763184
83 0.0005896091461181641
87 0.08681213855743408
91 0.022571653127670288
97 0.007641196250915527
102 0.03957938402891159
103 0.11112213134765625
107 0.043515644967556
118 0.042784884572029114
Chen-Cai-OSU commented 3 years ago

Here is the code to reproduce the result. Is this expected?

import dgl
import torch
from copy import deepcopy

from equivariant_attention.from_se3cnn.SO3 import rot

torch.manual_seed(0)
import torch.nn as nn
from torch.utils.data import DataLoader

from equivariant_attention.fibers import Fiber
from equivariant_attention.modules import GSE3Res, GNormSE3, GConvSE3, GMaxPooling, get_basis_and_r
from experiments.qm9.QM9 import QM9Dataset
from dgl.nn.pytorch.glob import AvgPooling

class GAvgVecPooling(nn.Module):
    """Graph average pooling for degree-1 (vector) node features.

    Each Cartesian component of ``features['1']`` is pooled separately with
    dgl's ``AvgPooling`` (a per-graph mean over nodes). The pooled components
    are then concatenated.
    """

    def __init__(self):
        super().__init__()
        self.pool = AvgPooling()

    def forward(self, features, G, **kwargs):
        """Pool the type-1 features of ``G``.

        Args:
            features: dict of per-degree node features. Only ``'1'`` is read.
                Indexing ``[..., i]`` for i in 0..2 assumes the last axis holds
                the 3 vector components (presumably (num_nodes, channels, 3);
                confirm against the producing layer).
            G: the (batched) graph handed to ``self.pool``.
            **kwargs: ignored. Accepted so this layer can be called like the
                other SE(3) layers (``r=``, ``basis=``).

        Returns:
            The pooled x, y and z components concatenated along dim 1. The
            layout is component-major: all channels of x, then y, then z.
        """
        vectors = features['1']
        # Print the size directly. summary() prints and returns None, so it
        # cannot be embedded in an f-string.
        print(f'before pool: Size: {vectors.size()}')
        pooled = [self.pool(G, vectors[..., i]) for i in range(3)]
        return torch.cat(pooled, dim=1)

def build_model():
    """Build the two-layer SE(3) model used by this equivariance check.

    Layers, applied in order:
      1) ``GSE3Res``: a multihead attention block mapping ``num_features``
         type-0 channels to 16 channels each of types 0 and 1.
      2) ``GConvSE3``: a TFN layer (no attention) with self-interaction,
         mapping to a single type-1 channel.

    Returns:
        nn.ModuleList holding the two layers.
    """
    # Fiber(num_degrees, num_channels): num_degrees is a count of 0-based
    # feature types, so Fiber(2, 16) carries types 0 and 1 with 16 channels
    # each, and Fiber(1, n) carries only type 0. An explicit ``structure``
    # lists the fiber's entries; here a single channel of type 1.
    num_degrees = 2
    num_features = 16
    fiber_in = Fiber(1, num_features)
    fiber_mid = Fiber(num_degrees, 16)
    fiber_out = Fiber(2, 1, structure=[(1, 1)])

    return nn.ModuleList([
        GSE3Res(fiber_in, fiber_mid),
        GConvSE3(fiber_mid, fiber_out, self_interaction=True),
    ])

def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into one batched graph and a label tensor."""
    graphs, labels = (list(column) for column in zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels)

def summary(features):
    """Print the size of a tensor, or of every tensor in a degree-keyed dict.

    Output goes to stdout only; the function returns None.
    """
    if not isinstance(features, dict):
        print(f'Size: {features.size()}')
        return
    for degree, tensor in features.items():
        print(f'type: {degree}; Size: {tensor.size()}')

def set_feat(G, R, num_features=16):
    """Rotate the edge vectors of ``G`` and reset its features for the test.

    Mutates ``G`` in place:
      * ``edata['d']`` is right-multiplied by ``R``;
      * ``edata['w']`` and ``ndata['x']`` are replaced by zero-width tensors
        (the test intends to feed only ``d`` and ``f`` to the model);
      * ``ndata['f']`` becomes all ones of shape (num_nodes, num_features, 1).

    Returns:
        ``(G, features)``, where ``features = {'0': G.ndata['f']}`` is the
        type-0 input dict consumed by the SE(3) layers (activations travel as
        a dict keyed by feature type, as an integer string).
    """
    num_edges = G.edata['w'].size(0)
    num_nodes = G.ndata['x'].size(0)
    G.edata['d'] = G.edata['d'] @ R
    G.edata['w'] = torch.rand((num_edges, 0))
    G.ndata['x'] = torch.rand((num_nodes, 0))
    G.ndata['f'] = torch.ones((G.ndata['f'].size(0), num_features, 1))
    return G, {'0': G.ndata['f']}

def apply_model(model, G, features, num_degrees=2):
    """Run every layer of ``model`` on ``G`` and return the first type-1 channel.

    ``get_basis_and_r(G, num_degrees - 1)`` is evaluated once, and its outputs
    are passed to every layer as ``basis=`` and ``r=``.

    Returns:
        ``features['1'][:, 0, :]`` taken from the final layer's output dict.
    """
    edge_basis, edge_r = get_basis_and_r(G, num_degrees - 1)
    current = features
    for layer in model:
        current = layer(current, G=G, r=edge_r, basis=edge_basis)
    type1 = current['1']
    return type1[:, 0, :]

if __name__ == '__main__':
    # Equivariance check over the whole dataset: print the graph index and
    # error for every graph whose absolute equivariance error exceeds tol.
    dataset = QM9Dataset('./QM9_data.pt', "homo", mode='train', fully_connected=True)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate)
    model = build_model()
    tol = 1e-5
    # The rotations do not depend on the graph, so build them once.
    R1 = rot(0, 0, 0)
    R2 = rot(10, 0, 30)

    for i, (G, _label) in enumerate(dataloader):
        G1, features1 = set_feat(deepcopy(G), R1)
        out1 = apply_model(model, G1, features1)

        G2, features2 = set_feat(deepcopy(G), R2)
        out2 = apply_model(model, G2, features2)

        # Absolute value: a signed max would hide large negative deviations.
        diff = torch.max(torch.abs(out2 - out1 @ R2)).item()
        if diff > tol:
            print(i, diff)
FabianFuchsML commented 3 years ago

I just pushed a commit that should fix this. There was a bug in how the fully connected graph was built. In some cases it created 'self-references', i.e. a point attending to itself with a relative position vector of 0. If one then evaluates the spherical harmonics naively at (0,0,0), equivariance is broken: rotating (0,0,0) still gives (0,0,0), so the zero vector has no well-defined direction. In the paper, we didn't use the fully connected version because it gave worse results than the sparse one. It would certainly be interesting to try again now that the bug is fixed. Good find!

Chen-Cai-OSU commented 3 years ago

Wonderful! Thank you for the quick fix!

Chen-Cai-OSU commented 3 years ago
# Created at 2020-11-25
# Summary: small norm after pooling

from copy import deepcopy

import dgl
import torch

from equivariant_attention.from_se3cnn.SO3 import rot

torch.manual_seed(0)
import torch.nn as nn
from torch.utils.data import DataLoader

from equivariant_attention.fibers import Fiber
from equivariant_attention.modules import GConvSE3, get_basis_and_r
from experiments.qm9.QM9 import QM9Dataset
from dgl.nn.pytorch.glob import AvgPooling

class GAvgVecPooling_(nn.Module):
    """Graph Average Pooling module."""

    def __init__(self):
        super().__init__()
        self.pool = AvgPooling()

    def forward(self, features, G, **kwargs):
        h_vec = []
        print(f'before pooling norm: {torch.norm(features["1"])}')
        for i in range(3):
            h = features['1'][..., i]
            h_vec.append(self.pool(G, h))
        ret = torch.cat(h_vec, axis=1)
        print(f'after pooling norm: {torch.norm(ret)}')
        return ret

def build_model(num_features=32, mid_dim=125):
    """Build one TFN layer followed by type-1 average pooling, and print it.

    Args:
        num_features: number of type-0 input channels (``Fiber(1, num_features)``).
        mid_dim: accepted but not used by this function.

    Returns:
        nn.ModuleList of ``GConvSE3`` (to one type-1 channel, with
        self-interaction) followed by ``GAvgVecPooling_``.
    """
    conv = GConvSE3(Fiber(1, num_features), Fiber(2, 1, structure=[(1, 1)]), self_interaction=True)
    layers = nn.ModuleList([conv, GAvgVecPooling_()])
    print(layers)
    return layers

def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into one batched graph and a label tensor."""
    graphs, labels = (list(column) for column in zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels)

def summary(features):
    """Print the size of a tensor, or of every tensor in a degree-keyed dict.

    Output goes to stdout only; the function returns None.
    """
    if not isinstance(features, dict):
        print(f'Size: {features.size()}')
        return
    for degree, tensor in features.items():
        print(f'type: {degree}; Size: {tensor.size()}')

def set_feat(G, R, num_features=32):
    """Rotate the edge vectors of ``G`` and reset its features for the test.

    Mutates ``G`` in place: ``edata['d']`` is right-multiplied by ``R``,
    ``edata['w']`` and ``ndata['x']`` become zero-width tensors, and
    ``ndata['f']`` becomes all ones of shape (num_nodes, num_features, 1).

    Returns:
        ``(G, {'0': G.ndata['f']})``, i.e. the graph and its type-0 input dict.
    """
    num_edges = G.edata['w'].size(0)
    num_nodes = G.ndata['x'].size(0)
    G.edata['d'] = G.edata['d'] @ R
    G.edata['w'] = torch.rand((num_edges, 0))
    G.ndata['x'] = torch.rand((num_nodes, 0))
    G.ndata['f'] = torch.ones((G.ndata['f'].size(0), num_features, 1))
    return G, {'0': G.ndata['f']}

# @profile
def apply_model(model, G, features, num_degrees=2):
    """Run every layer of ``model`` on ``G`` and return the last layer's output.

    ``get_basis_and_r(G, num_degrees - 1)`` is evaluated once, and its outputs
    are passed to every layer as ``basis=`` and ``r=``. The size and the mean
    (not the norm) of the final output are printed before it is returned.
    With the model from this script's ``build_model``, the last layer is the
    pooling layer, so the output is a pooled tensor rather than a dict.
    """
    edge_basis, edge_r = get_basis_and_r(G, num_degrees - 1)
    output = features
    for layer in model:
        output = layer(output, G=G, r=edge_r, basis=edge_basis)
    summary(output)
    # NOTE(review): a ~1e-8 value here is plausibly expected rather than a
    # bug. The inputs are all ones (type 0) and the edge features 'w' are
    # zero-width. So for a single TFN layer on a fully connected graph, the
    # messages along i->j and j->i presumably share the same radial weight
    # (it depends only on |d|), while the degree-1 angular part flips sign
    # (d_ji = -d_ij). Summed over all nodes, they would cancel pairwise.
    # Confirm before treating this as a bug.
    print(torch.mean(output))
    return output

if __name__ == '__main__':
    # Equivariance check of the pooled type-1 output: print the graph index
    # and the largest absolute deviation between f(R x) and f(x) R for every
    # graph.
    num_features = 12
    dataset = QM9Dataset('./QM9_data.pt', "homo", mode='train', fully_connected=True)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate)
    model = build_model(num_features=num_features, mid_dim=123)
    # The rotations do not depend on the graph, so build them once.
    R1 = rot(0, 0, 0)
    R2 = rot(10, 0, 30)

    for i, (G, _label) in enumerate(dataloader):
        G1, features1 = set_feat(deepcopy(G), R1, num_features=num_features)
        out1 = apply_model(model, G1, features1)

        G2, features2 = set_feat(deepcopy(G), R2, num_features=num_features)
        out2 = apply_model(model, G2, features2)

        # Absolute value: a signed max would hide negative deviations.
        diff = torch.max(torch.abs(out2 - out1 @ R2)).item()
        print(i, diff)
Chen-Cai-OSU commented 3 years ago

Hi Fabian,

It seems that after pooling, the norm of the features becomes very small (~10^-8). I do not know why this is the case. Is there something wrong with my mean pooling layer? I also tried sum pooling, but the issue is the same. It's a bit strange that when adding up all type-1 features over the nodes, they essentially cancel out.

the output looks like the following:

before pooling norm: 3.1610512733459473
after pooling norm: 5.6950220539420116e-08
Size: torch.Size([1, 3])
tensor(-2.2221e-08, grad_fn=<MeanBackward0>)
41 3.519508950944328e-08
before pooling norm: 2.5761053562164307
after pooling norm: 4.124186503418059e-08
Size: torch.Size([1, 3])
tensor(-2.1110e-08, grad_fn=<MeanBackward0>)
before pooling norm: 2.5761055946350098
after pooling norm: 8.165373799329245e-08
Size: torch.Size([1, 3])
tensor(-3.7157e-08, grad_fn=<MeanBackward0>)
FabianFuchsML commented 3 years ago

Is there something wrong with my mean pooling layer?

Your implementation of the pooling layer looks correct. I would assume that the (unwanted) symmetry comes from something that happens beforehand.