lucidrains / se3-transformer-pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in PyTorch. This specific repository is geared towards integration with an eventual AlphaFold2 replication.
MIT License

CPU/CUDA masking error #2

Closed denjots closed 3 years ago

denjots commented 3 years ago

Hi, nice work! I was just testing out your code and ran into the following error, which only occurs in the backward pass:

RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mask' in call to _th_masked_scatter_bool_

With the PyTorch nightly build, the error message is:

RuntimeError: Tensor for argument #2 'mask' is on CPU, but expected it to be on GPU (while checking arguments for masked_scatter_)

I'm pretty sure I don't have any tensors in CPU memory, but I'm not sure whether this is a bug in your SE3 code or a PyTorch issue. My gut feeling is that it's a PyTorch/autograd issue, but I don't know these particular PyTorch ops well enough to be sure. I tried both the PyTorch 1.7.1 release and the latest nightly. Judging by the PyTorch issue tracker, there seems to have been recent work on masked_scatter.
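One way to double-check a claim like "no tensors in CPU memory" is to collect the devices of a model's parameters, buffers, and inputs. The helper below (`devices_in_use`, a hypothetical name) is a minimal sketch; `nn.Linear` stands in for the SE3Transformer. It cannot see tensors created inside `forward()`, which is exactly where a stray CPU tensor can hide.

```python
import torch
from torch import nn


def devices_in_use(model: nn.Module, *tensors: torch.Tensor) -> set:
    """Return the device strings used by the model's parameters, buffers and the given inputs.

    Tensors created inside forward() are not visible to this check.
    """
    devices = {str(p.device) for p in model.parameters()}
    devices |= {str(b.device) for b in model.buffers()}
    devices |= {str(t.device) for t in tensors}
    return devices


# Stand-in model and input; on a GPU setup you would expect {'cuda:0'} instead.
model = nn.Linear(3, 2)
x = torch.randn(5, 3)
print(devices_in_use(model, x))  # {'cpu'}
```

If the set contains more than one device, something was not moved; if it contains only the GPU and the error persists, the culprit is likely a tensor created inside the model.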

denjots commented 3 years ago

OK, slight update: this error is triggered when the inputs are not contiguous; forcing the input tensors to be contiguous works around the problem. I'm not sure whether this is expected behaviour. Feel free to close if it is.
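"Forcing the input tensors to be contiguous" presumably means calling `Tensor.contiguous()` on each input before passing it to the model. The sketch below assumes that reading; `make_contiguous` is a hypothetical helper, and the transpose is only one illustrative way a non-contiguous tensor can arise.

```python
import torch


def make_contiguous(*tensors: torch.Tensor) -> tuple:
    """Return a contiguous version of each tensor (returned as-is if already contiguous)."""
    return tuple(t.contiguous() for t in tensors)


# Hypothetical input: transposing swaps the strides, so the result is a
# non-contiguous view of the same memory.
coors = torch.randn(2, 3, 32).transpose(1, 2)  # shape (2, 32, 3)
mask = torch.ones(2, 32, dtype=torch.bool)

print(coors.is_contiguous())  # False

coors_c, mask_c = make_contiguous(coors, mask)
print(coors_c.is_contiguous())      # True
print(torch.equal(coors_c, coors))  # True: same values, new memory layout
```

`.contiguous()` keeps the values, shape and device; only the memory layout changes, so it is a cheap thing to try when an op misbehaves on views.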

lucidrains commented 3 years ago

@denjots Oh hi! Thanks for reporting this! That's good to know; I'll double-check the library for bugs later this weekend.

What are you using the equivariant attention for, may I ask?

denjots commented 3 years ago

That's great, thanks. I'm doing a project on predicting drug properties, so quite similar to the QM9 benchmark; hence my interest in the paper by Fabian Fuchs. I came across your GitHub page while searching for his page again.

lucidrains commented 3 years ago

@denjots Very nice! I was planning on getting QM9 working with this repo! (I decided not to use a graph neural net library, for accessibility reasons.)

Do let me know how it goes with your project. I can't promise that it is free of bugs, even though equivariance does seem to hold given the tests :)

jgreener64 commented 3 years ago

I ran into this issue too. Here's a minimal example:

import torch
from se3_transformer_pytorch import SE3Transformer

dev = "cuda"  # "cpu" works; "cuda" fails in the backward pass

model = SE3Transformer(
    num_tokens = 4,
    dim = 4,
    num_edge_tokens = 4,
    edge_dim = 4,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 1,
    reduce_dim_out = True
).to(dev)

# batch of 2 graphs with 32 nodes each; every input is created directly on `dev`
atoms = torch.randint(0, 4, (2, 32), device=dev)      # node (atom) type tokens
bonds = torch.randint(0, 4, (2, 32, 32), device=dev)  # edge (bond) type tokens
coors = torch.randn(2, 32, 3, device=dev)             # 3D coordinates
mask  = torch.ones(2, 32, device=dev).bool()          # all nodes are valid

pred = model(atoms, coors, mask = mask, edges = bonds, return_type = 0)

loss = pred.sum()
loss.backward()  # raises the RuntimeError below when dev = "cuda"

This works when dev = "cpu", but when dev = "cuda" it gives the following:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~\Documents\work\screening\se3_test\temp.py in <module>
     24
     25 loss = pred.sum()
---> 26 loss.backward()

~\miniconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
    183                 products. Defaults to ``False``.
    184         """
--> 185         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    186
    187     def register_hook(self, hook):

~\miniconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    125     Variable._execution_engine.run_backward(
    126         tensors, grad_tensors, retain_graph, create_graph,
--> 127         allow_unreachable=True)  # allow_unreachable flag
    128
    129

RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mask' in call to _th_masked_scatter_bool_
Exception raised from checked_dense_tensor_unwrap at ..\aten\src\ATen/Utils.h:39 (most recent call first):
00007FFE464B75A200007FFE464B7540 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFDCA60C11000007FFDCA57E0A0 torch_cuda.dll!at::native::set_storage_cuda_ [<unknown file> @ <unknown line number>]
00007FFDCA5F81A200007FFDCA57E0A0 torch_cuda.dll!at::native::set_storage_cuda_ [<unknown file> @ <unknown line number>]
00007FFDCA57D9E700007FFDCA57D810 torch_cuda.dll!at::native::masked_scatter__cuda [<unknown file> @ <unknown line number>]
00007FFDCA5DE98C00007FFDCA57E0A0 torch_cuda.dll!at::native::set_storage_cuda_ [<unknown file> @ <unknown line number>]
00007FFDC39E3A5800007FFDC393E010 torch_cpu.dll!torch::autograd::GraphRoot::apply [<unknown file> @ <unknown line number>]
00007FFDC389C3A900007FFDC389C1A0 torch_cpu.dll!torch::autograd::generated::MaskedSelectBackward::apply [<unknown file> @ <unknown line number>]
00007FFDC3877E9100007FFDC3877B50 torch_cpu.dll!torch::autograd::Node::operator() [<unknown file> @ <unknown line number>]
00007FFDC3DDF9BA00007FFDC3DDF300 torch_cpu.dll!torch::autograd::Engine::add_thread_pool_task [<unknown file> @ <unknown line number>]
00007FFDC3DE03AD00007FFDC3DDFFD0 torch_cpu.dll!torch::autograd::Engine::evaluate_function [<unknown file> @ <unknown line number>]
00007FFDC3DE4FE200007FFDC3DE4CA0 torch_cpu.dll!torch::autograd::Engine::thread_main [<unknown file> @ <unknown line number>]
00007FFDC3DE4C4100007FFDC3DE4BC0 torch_cpu.dll!torch::autograd::Engine::thread_init [<unknown file> @ <unknown line number>]
00007FFE0BFB0A7700007FFE0BF8A150 torch_python.dll!THPShortStorage_New [<unknown file> @ <unknown line number>]
00007FFDC3DDBF1400007FFDC3DDB780 torch_cpu.dll!torch::autograd::Engine::get_base_engine [<unknown file> @ <unknown line number>]
00007FFE6B4710B200007FFE6B470F70 ucrtbase.dll!beginthreadex [<unknown file> @ <unknown line number>]
00007FFE6E047C2400007FFE6E047C10 KERNEL32.DLL!BaseThreadInitThunk [<unknown file> @ <unknown line number>]
00007FFE6E3ED4D100007FFE6E3ED4B0 ntdll.dll!RtlUserThreadStart [<unknown file> @ <unknown line number>]

I am on PyTorch 1.6.0 and version 0.0.11 of this package.
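For reports like this, the versions quoted above can be gathered in one place. The helper below (`environment_report`, a hypothetical name) is a minimal sketch; it assumes Python 3.8+ for `importlib.metadata` and that the package's PyPI distribution name is `se3-transformer-pytorch` (the import name is `se3_transformer_pytorch`).

```python
import importlib.metadata

import torch


def environment_report(package: str = "se3-transformer-pytorch") -> dict:
    """Collect version details worth quoting in a bug report."""
    try:
        pkg_version = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        pkg_version = None  # package not installed in this environment
    return {
        "torch": torch.__version__,
        "cuda": torch.version.cuda,  # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        package: pkg_version,
    }


print(environment_report())
```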

As I've said elsewhere, great package!

lucidrains commented 3 years ago

@jgreener64 Thank you for the error-reproducing script! I had missed passing in the device for one of the PyTorch tensors instantiated in the library code (fix: https://github.com/lucidrains/se3-transformer-pytorch/commit/56897f8f0d56b4d22c97f21a80295fafa79c679e). It should work now as of 0.0.12!
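The linked commit holds the actual change. The snippet below is only a hypothetical illustration of the general pattern: a tensor created inside a function without `device=` lands on the CPU even when the inputs live on CUDA, and the traceback shows `MaskedSelectBackward` calling `masked_scatter`, which needs its mask on the same device as the gradient. `masked_sum` is an invented helper, not part of se3-transformer-pytorch.

```python
import torch


def masked_sum(x: torch.Tensor, n_valid: int) -> torch.Tensor:
    """Sum the first `n_valid` rows of `x` (hypothetical helper for illustration)."""
    # Writing torch.arange(x.shape[0]) without device= would create the mask on
    # the CPU; passing device=x.device keeps it next to x on any device.
    mask = torch.arange(x.shape[0], device=x.device) < n_valid  # shape (rows,)
    # masked_select broadcasts the (rows, 1) mask over the columns; its backward
    # scatters the gradient through the same mask.
    selected = torch.masked_select(x, mask.unsqueeze(-1))
    return selected.sum()


dev = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 3, device=dev, requires_grad=True)
loss = masked_sum(x, n_valid=2)
loss.backward()
print(x.grad)  # ones in the first two rows, zeros in the last two
```

Deriving the device from an existing input (`device=x.device`, or `torch.ones_like(x)`) keeps a helper device-agnostic without threading a `device` argument through every call.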

@denjots ^

jgreener64 commented 3 years ago

💯 Thanks a lot.

LeMei commented 2 years ago

Quoting @denjots: "OK, slight update. This error is triggered when the inputs are not contiguous - forcing the input tensors to be contiguous works around the problem. Not sure if this is expected behaviour or not. Feel free to close if this is expected."

I also ran into this issue. Could you explain the workaround of "forcing the input tensors to be contiguous"? I don't understand it and don't know how to solve this issue.

I look forward to your reply. Thanks!