lucidrains / se3-transformer-pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This specific repository is geared towards integration with eventual Alphafold2 replication.

SE3Transformer constructor hangs #14


mpdprot commented 3 years ago

I am trying to run an example from the README. The code is:

import torch
from se3_transformer_pytorch import SE3Transformer

print('Initialising model...')
model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

print('Running model...')
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

The script prints 'Initialising model...' and then hangs; eventually the kernel dies.

Any ideas why this would be happening?
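
One way to see where the constructor is stuck is to arm a watchdog that dumps every thread's stack if the process is still busy after a timeout. This is a generic diagnostic sketch using only the Python standard library, not anything specific to this repo:

import faulthandler
from se3_transformer_pytorch import SE3Transformer

# Print all thread stacks to stderr if we are still running after
# 60 seconds, without killing the process.
faulthandler.dump_traceback_later(60, exit=False)

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

faulthandler.cancel_dump_traceback_later()  # constructor returned; disarm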

Here is my pip freeze:

anyio==3.2.1
argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613036642480/work
astunparse==1.6.3
async-generator==1.10
attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
axial-positional-embedding==0.2.1
Babel==2.9.1
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
biopython==1.79
bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
cached-property @ file:///tmp/build/80754af9/cached-property_1600785575025/work
certifi==2021.5.30
cffi @ file:///tmp/build/80754af9/cffi_1613246939562/work
chardet==4.0.0
click==8.0.1
configparser==5.0.2
decorator==4.4.2
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
dgl-cu101==0.4.3.post2
dgl-cu110==0.6.1
docker-pycreds==0.4.0
egnn-pytorch==0.2.6
einops==0.3.0
En-transformer==0.3.8
entrypoints==0.3
equivariant-attention @ file:///workspace/projects/se3-transformer-public
filelock==3.0.12
gitdb==4.0.7
GitPython==3.1.18
graph-transformer-pytorch==0.0.1
h5py @ file:///tmp/build/80754af9/h5py_1622088444809/work
huggingface-hub==0.0.12
idna==2.10
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617877314848/work
ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl
ipython @ file:///tmp/build/80754af9/ipython_1617118429768/work
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
jedi==0.17.0
Jinja2 @ file:///tmp/build/80754af9/jinja2_1621238361758/work
joblib==1.0.1
json5==0.9.6
jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
jupyter==1.0.0
jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213308260/work
jupyter-server==1.8.0
jupyter-tensorboard==0.2.0
jupyterlab==3.0.16
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
jupyterlab-server==2.6.0
jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
jupytext==1.11.3
lie-learn @ git+https://github.com/AMLab-Amsterdam/lie_learn@07469085ac0fd4550fd26ff61cb10bb1e92cead1
llvmlite==0.36.0
local-attention==1.4.1
markdown-it-py==1.1.0
MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work
mdit-py-plugins==0.2.8
mdtraj==1.9.6
mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work
mkl-fft==1.3.0
mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
mkl-service==2.3.0
mp-nerf==0.1.11
nbclassic==0.3.1
nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work
nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
networkx==2.5.1
notebook @ file:///tmp/build/80754af9/notebook_1621523661196/work
numba==0.53.1
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
pandas==1.2.4
pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work
parso @ file:///tmp/build/80754af9/parso_1617223946239/work
pathtools==0.1.2
performer-pytorch==1.0.11
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
ProDy==2.0
prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
promise==2.3
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
protobuf==3.17.3
psutil==5.8.0
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
py3Dmol==0.9.1
pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work
pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work
python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
PyYAML==5.4.1
pyzmq==20.0.0
qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work
QtPy==1.9.0
regex==2021.4.4
requests==2.25.1
sacremoses==0.0.45
scipy @ file:///tmp/build/80754af9/scipy_1618852618548/work
se3-transformer-pytorch==0.8.10
Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
sentry-sdk==1.1.0
shortuuid==1.0.1
sidechainnet==0.6.0
six @ file:///tmp/build/80754af9/six_1623709665295/work
smmap==4.0.0
sniffio==1.2.0
subprocess32==3.5.4
terminado==0.9.4
testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
tokenizers==0.10.3
toml==0.10.2
torch==1.9.0
tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
tqdm==4.61.1
traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
transformers==4.8.0
typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
urllib3==1.26.5
wandb==0.10.32
wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
webencodings==0.5.1
websocket-client==1.1.0
widgetsnbextension==3.5.1
zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work

Here is a summary of my system info (lshw -short):

H/W path    Device  Class      Description
==========================================
                    system     Computer
/0                  bus        Motherboard
/0/0                memory     59GiB System memory
/0/1                processor  Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
/0/100              bridge     440FX - 82441FX PMC [Natoma]
/0/100/1            bridge     82371SB PIIX3 ISA [Natoma/Triton II]
/0/100/1.1          storage    82371SB PIIX3 IDE [Natoma/Triton II]
/0/100/1.3          bridge     82371AB/EB/MB PIIX4 ACPI
/0/100/2            display    GD 5446
/0/100/3            network    Elastic Network Adapter (ENA)
/0/100/1e           display    GK210GL [Tesla K80]
/0/100/1f           generic    Xen Platform Device
/1          eth0    network    Ethernet interface
MattMcPartlon commented 3 years ago

My best guess is that the radial basis functions are taking a long time to initialize... That is a huge model, well over 1 TB of memory huge.

With 6 heads of dimension 48 and a hidden dim of 256, the model uses ~48 GB on an input of size 300, and that is with max degree 2. You're looking at roughly 100x that with your settings, since higher-order types incur roughly degree-squared overhead in memory.
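
To make the degree-squared claim concrete, here is a back-of-envelope sketch. The 2l + 1 component count per type-l feature is standard; the constant factors and the exact way the pieces multiply below are rough assumptions, not measurements:

# A type-l feature has 2l + 1 components, so the pairwise kernel
# blocks across all (l_in, l_out) combinations scale roughly like
# (sum over types of (2l + 1)) squared. Constant factors ignored.

def degree_factor(num_degrees):
    return sum(2 * l + 1 for l in range(num_degrees)) ** 2

print(degree_factor(2))   # types 0..1 -> 16
print(degree_factor(4))   # types 0..3 -> 256, i.e. 16x the degree-2 cost

# On top of that, attention over 1024 points instead of 300 costs
# (1024 / 300)^2 ~ 11.6x in pairwise terms, and the wider hidden
# dims add a few more x, which lands in the 100x-plus ballpark.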

I would recommend specifying much smaller hidden dimensions for degrees > 0, and no more than 256 for the type-0 hidden dimension.
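
For example, a scaled-down configuration along those lines (a sketch only, reusing the constructor arguments from the example above; the exact values are illustrative, and whether it fits still depends on your hardware):

import torch
from se3_transformer_pytorch import SE3Transformer

# Far smaller model: fewer and narrower heads, shallower, and only
# degrees 0-1 instead of 0-3.
model = SE3Transformer(
    dim = 64,
    heads = 2,
    depth = 2,
    dim_head = 32,
    num_degrees = 2,
    valid_radius = 10
)

feats = torch.randn(1, 256, 64)   # feature dim must match dim above
coors = torch.randn(1, 256, 3)
mask  = torch.ones(1, 256).bool()

out = model(feats, coors, mask)   # (1, 256, 64)

If that runs, scale one knob at a time (depth, then num_degrees, then dim) to find where your machine tops out.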