generatebio / chroma

A generative model for programmable protein design
Apache License 2.0

Tensors not on the same device #40

Closed: Bloeci closed this issue 9 months ago

Bloeci commented 10 months ago

When I run one of the minimal examples:

```python
# test.py
from chroma import Chroma, Protein

chroma = Chroma()
protein = Protein('1GFP')
protein = chroma.design(protein)

protein.to("1GFP-redesign.cif")
```

I get an error saying that some tensors are not on the same device:

```
(chroma_env) Ξ software/chroma git:(main) ▶ python3 test.py
Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt
Loaded from cache
cuda
Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt
Loaded from cache
Traceback (most recent call last):
  File "/home/iwe34/software/chroma/test.py", line 5, in <module>
    protein = chroma.design(protein)
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/models/chroma.py", line 532, in design
    X_sample, S_sample, _ = self.design_network.sample(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/anaconda3/envs/chroma_env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/data/xcs.py", line 114, in new_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/models/graph_design.py", line 826, in sample
    node_h, edge_h, edge_idx, mask_i, mask_ij = self.encode(X, C, t=t)
                                                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/data/xcs.py", line 114, in new_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/models/graph_design.py", line 536, in encode
    node_h, edge_h, edge_idx, mask_i, mask_ij = self.encoder(
                                                ^^^^^^^^^^^^^
  File "/home/iwe34/anaconda3/envs/chroma_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/data/xcs.py", line 114, in new_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/models/graph_design.py", line 1237, in forward
    node_h, edge_h, edge_idx, mask_i, mask_ij = self._checkpoint(
                                                ^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/models/graph_design.py", line 1253, in _checkpoint
    return module(*args)
           ^^^^^^^^^^^^^
  File "/home/iwe34/anaconda3/envs/chroma_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/iwe34/software/chroma/chroma/layers/structure/protein_graph.py", line 224, in forward
    node_h_l = node_h_l - self.__getattr__(f"node_means_{i}")
               ~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

When I try to fix this by moving everything on that particular line to the GPU, the same error occurs on another line:

```
  File "/home/iwe34/software/chroma/chroma/layers/structure/protein_graph.py", line 1443, in forward
    h = torch.exp(-(((h.unsqueeze(-1) - rbf_centers) / self.std) ** 2))
                      ~~~~~~~~~~~~~~~~^~~~~~~~~~~~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

It looks to me like many other tensors would show the same problem, and that there is no easy fix, because the device doesn't seem to be shared across all classes :(
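Both failing lines mix two tensors that live on different devices (the log prints `cuda`, so the model was placed on the GPU). Below is a minimal sketch of that mechanism. It uses a hypothetical `ToyNodeNormalizer` layer, not Chroma's actual code, and assumes the stored means behave like a registered buffer. Calling `.to(device)` on a module moves its buffers and parameters, but a tensor you create yourself stays on the CPU until you move it, so the input has to be sent to the model's device before the subtraction.

```python
import torch
from torch import nn


class ToyNodeNormalizer(nn.Module):
    """Hypothetical stand-in for a layer that subtracts stored feature means.

    The means are registered as a buffer, so `module.to(device)` moves them
    together with the model. Input tensors are NOT moved automatically.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.register_buffer("node_means_0", torch.zeros(dim))

    def forward(self, node_h: torch.Tensor) -> torch.Tensor:
        # Raises "Expected all tensors to be on the same device" if node_h
        # is on a different device than the buffer.
        return node_h - self.node_means_0


def run_on_model_device(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Find where the module's buffers live and move the input there first.
    device = next(module.buffers()).device
    return module(x.to(device))


device = "cuda" if torch.cuda.is_available() else "cpu"
layer = ToyNodeNormalizer(dim=4).to(device)
x = torch.ones(2, 4)  # created on the CPU by default, like an unmoved input
out = run_on_model_device(layer, x)
print(out.device)
```

Calling `layer(x)` directly on a GPU machine would reproduce the same `RuntimeError`; moving the input once, at the entry point, avoids patching every layer individually.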

wujiewang commented 10 months ago

Thanks for the question.

The fix is easy: you just need to specify the device when you create the `Protein`.

```python
from chroma import Chroma, Protein

chroma = Chroma()
protein = Protein('1GFP', device='cuda')
protein = chroma.design(protein)

protein.to("1GFP-redesign.cif")
```

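Hard-coding `device='cuda'` assumes a GPU is present. One hedged alternative is to read the device off the model and pass that to `Protein`. The helper `module_device` below is hypothetical and not part of Chroma. The commented usage assumes two things: that `chroma.design_network` is a `torch.nn.Module`, as the traceback suggests, and that `Protein` forwards its `device` argument to PyTorch, which accepts strings such as `"cpu"` or `"cuda:0"`.

```python
import torch
from torch import nn


def module_device(module: nn.Module) -> torch.device:
    """Return the device of a module's first parameter, else its first buffer."""
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    raise ValueError("module has no parameters or buffers")


# Hypothetical usage with Chroma (not verified against the Chroma API):
#   device = module_device(chroma.design_network)
#   protein = Protein('1GFP', device=str(device))

model = nn.Linear(3, 2).to("cuda" if torch.cuda.is_available() else "cpu")
print(module_device(model))
```

Matching the input to the model, instead of naming a device twice, keeps the script working unchanged on CPU-only machines.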
wujiewang commented 9 months ago

Closing this due to inactivity. Feel free to reopen.