FabianFuchsML / se3-transformer-public

Code for the SE(3)-Transformers paper: https://arxiv.org/abs/2006.10503

Update October 2021

Check out this work by Alexandre Milesi et al. from NVIDIA. They managed to speed up training of the SE(3)-Transformer by up to 21(!) times and reduce memory consumption by up to 43 times. Code here.

SE(3)-Transformers

This repository is the official implementation of SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks.

Please cite us as

@inproceedings{fuchs2020se3transformers,
    title={SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks},
    author={Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year={2020},
    booktitle={Advances in Neural Information Processing Systems 33 (NeurIPS)},
}

Prerequisites

Update 2020/12/22: we made some changes to support dgl >= 0.5.0. The code is tested with the combinations dgl 0.5.3 & torch 1.7 and dgl 0.4.3.post2 & torch 1.4. We recommend using dgl 0.5.3 & torch 1.7.

Check requirements.txt for other dependencies.

Experiments

Experiment-specific code is meant to be placed in the folder experiments.

We provide implementations of the n-body experiment and the QM9 experiments. In addition, we provide a multi-step toy optimisation experiment loosely inspired by the protein structure prediction task.

Basic usage

The SE(3)-Transformer is built on top of DGL in PyTorch.

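# Import paths follow experiments/qm9/models.py; adjust if your layout differs
from torch import nn
from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormSE3, GConvSE3, GMaxPooling
from equivariant_attention.fibers import Fiber

###
# Define a toy model: more complex models in experiments/qm9/models.py
###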

# The maximum feature type is harmonic degree 3
num_degrees = 4

# The Fiber() object is a representation of the structure of the activations.
# Its first argument is the number of degrees (0-based), so num_degrees=4 leads
# to feature types 0,1,2,3. The second argument is the number of channels (aka
# multiplicities) for each degree. It is possible to have a varying number of
# channels/multiplicities per feature type and to use arbitrary feature types;
# for this functionality, check out fibers.py.

# num_features is the number of input channels per node (set it for your data)
fiber_in = Fiber(1, num_features)
fiber_mid = Fiber(num_degrees, 32)
fiber_out = Fiber(1, 128)

# We build a module from:
# 1) a multihead attention block
# 2) a nonlinearity
# 3) a TFN layer (no attention)
# 4) graph max pooling
# 5) a fully connected layer -> 1 output

model = nn.ModuleList([GSE3Res(fiber_in, fiber_mid),
                       GNormSE3(fiber_mid),
                       GConvSE3(fiber_mid, fiber_out, self_interaction=True),
                       GMaxPooling()])
fc_layer = nn.Linear(128, 1)

###
# Run model: complete train script in experiments/qm9/run.py
###

# Before each forward pass we make a call to get_basis_and_r, which computes
# the equivariant weight basis and relative positions of all the nodes in the
# graph. Pass these variables as keyword arguments to SE(3)-transformer layers.

basis, r = get_basis_and_r(G, num_degrees-1)

# Run SE(3)-Transformer layers: the activations are passed around as a dict,
# with the feature type (an integer in string form) as the key and a PyTorch
# tensor in the DGL node feature representation as the value.

features = {'0': G.ndata['my_features']}
for layer in model:
    features = layer(features, G=G, r=r, basis=basis)

# Run non-DGL layers: we can do this because GMaxPooling has converted features
# from the DGL node feature representation to the standard PyTorch tensor rep.
output = fc_layer(features)
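
To make the Fiber comments above more concrete, here is a minimal sketch in plain Python of the bookkeeping they describe. A type-l feature has 2l+1 components, so a fiber with the same number of channels for every degree holds channels * sum(2l+1) scalar entries per node. The helpers fiber_structure and fiber_size are hypothetical names introduced for illustration only, and the (channels, degree) pair layout is an assumption; see fibers.py for the actual Fiber class.

```python
# Illustrative only: these helpers are hypothetical and are NOT the
# repository's Fiber class (see fibers.py for that).

def fiber_structure(num_degrees, num_channels):
    """List (channels, degree) pairs for degrees 0 .. num_degrees - 1."""
    return [(num_channels, degree) for degree in range(num_degrees)]


def fiber_size(structure):
    """Count scalar entries per node: a type-l feature has 2l + 1 components."""
    return sum(channels * (2 * degree + 1) for channels, degree in structure)


fiber_mid = fiber_structure(num_degrees=4, num_channels=32)
print(fiber_mid)              # [(32, 0), (32, 1), (32, 2), (32, 3)]
print(fiber_size(fiber_mid))  # 32 * (1 + 3 + 5 + 7) = 512

fiber_out = fiber_structure(num_degrees=1, num_channels=128)
print(fiber_size(fiber_out))  # 128, matching nn.Linear(128, 1) above
```

Because the output fiber contains only type-0 (scalar, rotation-invariant) features, its 128 channels can be passed to an ordinary fully connected layer after pooling.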

FAQ

Type issues with QM9 experiments

One user reported that they experienced issues with data types when running the QM9 experiments. For them, adding the following lines just before line 184 of qm9.py fixed the issue:

x=x.astype(np.float32)
one_hot=one_hot.astype(np.float32)
atomic_numbers=atomic_numbers.astype(np.float32)
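
A likely cause of this kind of mismatch (an assumption, not something confirmed by the report) is that NumPy creates floating-point arrays as float64 by default, whereas PyTorch layers use float32 parameters by default. The following minimal sketch uses a made-up array rather than the actual QM9 data to show what the cast does:

```python
import numpy as np

# Made-up stand-in for the arrays in qm9.py; NumPy defaults to float64.
x = np.array([[0.0, 0.1, 0.2], [1.0, 1.1, 1.2]])
print(x.dtype)  # float64

# astype returns a new array with the requested dtype.
x = x.astype(np.float32)
print(x.dtype)  # float32
```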

Speed

Here are some ideas about speeding up the SE3-Transformer:

Credit to '3D Steerable CNNs'

The code in the subfolder equivariant_attention/from_se3cnn is strongly based on https://github.com/mariogeiger/se3cnn which accompanies the paper '3D Steerable CNNs: Learning Rotationally Equivariant Features in Volumetric Data' by Weiler et al.

Feedback & Questions

Please contact us at: fabian @ robots . ox . ac . uk

License

MIT License

Copyright (c) 2020 Fabian Fuchs and Daniel Worrall

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.