DeepGraphLearning / GearNet

GearNet and Geometric Pretraining Methods for Protein Structure Representation Learning, ICLR'2023 (https://arxiv.org/abs/2203.06125)
MIT License

RuntimeError: shape '[31, 147]' is invalid for input of size 651 in extracting embedding from pdb files #59

Open blazexpire opened 9 months ago

blazexpire commented 9 months ago

Thank you for your great work. When I followed the method from the previous issue to extract features from pdb files, I encountered the following problem.

Traceback (most recent call last):
  File "script\tcr_extract.py", line 48, in <module>
    output = gearnet_edge(protein, protein.node_feature.float(), all_loss=None, metric=None)
  File "C:\Users\mercy\anaconda3\envs\proj_v1\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\mercy\anaconda3\envs\proj_v1\lib\site-packages\torchdrug\models\gearnet.py", line 95, in forward
    hidden = self.layers[i](graph, layer_input)
  File "C:\Users\mercy\anaconda3\envs\proj_v1\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\mercy\anaconda3\envs\proj_v1\lib\site-packages\torchdrug\layers\conv.py", line 91, in forward
    update = self.message_and_aggregate(graph, input)
  File "C:\Users\mercy\anaconda3\envs\proj_v1\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\mercy\anaconda3\envs\proj_v1\lib\site-packages\torchdrug\layers\conv.py", line 813, in message_and_aggregate
    return update.view(graph.num_node, self.num_relation * self.input_dim)
RuntimeError: shape '[31, 147]' is invalid for input of size 651

I used exactly the same code as in that issue and only changed the pthfile to mc_gearnet_edge.pth. The error appears on the last line:

output = gearnet_edge(protein, protein.node_feature.float(), all_loss=None, metric=None)

The pdb file I used has 31 residues. I guess that 651 is 31*21, but I don't know why that would be the case. I hope I can get some help; I have been stuck on this problem for a week. Thanks, all.

Oxer11 commented 9 months ago

Hi, I'm wondering how you use the graph_construction_model. Usually, we use graph_construction_model to define a graph with 7 edge types and perform relational message passing. The implementation of relational message passing is a little different from what you thought. You can find the code here. I guess that the error you reported occurs because the graph was not constructed correctly with 7 edge types.

steffanpaul commented 9 months ago

Hello! I am running into the same issue as above. I am using the graph_construction_model as used in the torchprotein tutorial, but I still get the same error. When I run the same code on the EC proteins downloaded in the tutorial, I don't run into this issue. Does this have to do with incorrect use of the data.Protein.from_pdb function?

My code is as below (I'm reading in a pdb downloaded from rcsb).

import os, sys
pdb_dir = '/n/groups/marks/users/steffan/calibration_project/structures/pdb_repo'
pdb_file = os.path.join(pdb_dir, os.listdir(pdb_dir)[10])
pthdir = '/n/groups/marks/users/steffan/calibration_project/structure_kernels/gearnet/checkpoints'
pthfile = os.path.join(pthdir, 'angle_gearnet_edge.pth')

from torchdrug import data
from torchdrug import layers
from torchdrug.layers import geometry
from torchdrug import models
import torch

# protein
protein = data.Protein.from_pdb(pdb_file, atom_feature="position", bond_feature="length", residue_feature="symbol")
_protein = data.Protein.pack([protein])
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                   edge_feature="gearnet")
protein_ = graph_construction_model(_protein)

# model
gearnet_edge = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512],
                              num_relation=7, edge_input_dim=59, num_angle_bin=8,
                              batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")

net = torch.load(pthfile, map_location=torch.device('cpu'))
gearnet_edge.load_state_dict(net)

#output
gearnet_edge.eval()
with torch.no_grad():
    output = gearnet_edge(protein_, protein_.node_feature.float(), all_loss=None, metric=None)

And the error readout is

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[23], line 34
     32 gearnet_edge.eval()
     33 with torch.no_grad():
---> 34     output = gearnet_edge(protein_, protein_.node_feature.float(), all_loss=None, metric=None)

File /n/groups/marks/users/steffan/.conda/envs/seqmodels-calibration/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /n/groups/marks/users/steffan/.conda/envs/seqmodels-calibration/lib/python3.9/site-packages/torchdrug/models/gearnet.py:95, in GeometryAwareRelationalGraphNeuralNetwork.forward(self, graph, input, all_loss, metric)
     92     edge_input = line_graph.node_feature.float()
     94 for i in range(len(self.layers)):
---> 95     hidden = self.layers[i](graph, layer_input)
     96     if self.short_cut and hidden.shape == layer_input.shape:
     97         hidden = hidden + layer_input

File /n/groups/marks/users/steffan/.conda/envs/seqmodels-calibration/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /n/groups/marks/users/steffan/.conda/envs/seqmodels-calibration/lib/python3.9/site-packages/torchdrug/layers/conv.py:91, in MessagePassingBase.forward(self, graph, input)
     89     update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input)
     90 else:
---> 91     update = self.message_and_aggregate(graph, input)
     92 output = self.combine(input, update)
     93 return output

File /n/groups/marks/users/steffan/.conda/envs/seqmodels-calibration/lib/python3.9/site-packages/torchdrug/layers/conv.py:813, in GeometricRelationalGraphConv.message_and_aggregate(self, graph, input)
    809     edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
    810                               dim_size=graph.num_node * graph.num_relation)
    811     update += edge_update
--> 813 return update.view(graph.num_node, self.num_relation * self.input_dim)

RuntimeError: shape '[276, 147]' is invalid for input of size 5796

Oxer11 commented 9 months ago

Hi! I think the error is that you forgot to set protein_.view = 'residue'. This is done by using transforms.ProteinView in the tutorial. If you don't do this, the node_feature you feed into the model is atom_feature instead of residue_feature.
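
For readers who hit the same error, the arithmetic behind this diagnosis can be sketched as follows. This is only a sanity check, not torchdrug code: `inferred_feature_width` is a hypothetical helper, and it assumes the graph really has 7 relations, so that `update` has num_node * 7 rows. Under that assumption both reported sizes imply a feature width of 3 rather than 21, which would be consistent with a 3-dimensional atom feature such as atom_feature="position" being fed to the model. Note that 651 also equals 31 * 1 * 21, so the numbers alone cannot rule out the earlier explanation of a graph with a single relation.

```python
def inferred_feature_width(actual_size, num_node, num_relation=7):
    """Width of each row of `update`, ASSUMING it has num_node * num_relation rows.

    Hypothetical helper for illustration; it is not part of torchdrug.
    """
    rows = num_node * num_relation
    if actual_size % rows != 0:
        raise ValueError("size is not a multiple of num_node * num_relation")
    return actual_size // rows


# (num_node, size reported in the RuntimeError) for the two tracebacks above
reports = {"blazexpire": (31, 651), "steffanpaul": (276, 5796)}

widths = {
    name: inferred_feature_width(size, num_node)
    for name, (num_node, size) in reports.items()
}
print(widths)  # {'blazexpire': 3, 'steffanpaul': 3}
```

A width of 21 would be expected if residue features were being passed, since the model above is built with input_dim=21.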

gtamer2 commented 8 months ago

I am following the tutorial and having the equivalent issue on the update.view operation in the graph convolution layer. I have confirmed I am using residue_feature.

The graph has attributes PackedProtein(batch_size=1, num_atoms=[350], num_bonds=[6816], num_residues=[350])

Relevant dimensions from the convolution are:

self.num_relation = 7 # from config
self.input_dim =  21    # from config
update.shape= torch.Size([2450, 67]) # Note that 2450 = 350 * 7
graph.num_node= tensor(350, device='cuda:0')
self.num_relation * self.input_dim = 147

and we get

 return update.view(graph.num_node, self.num_relation * self.input_dim)
RuntimeError: shape '[350, 147]' is invalid for input of size 164150

Note that 350 * 147 gives a size of 51,450, and 51450 / 21 gives 2450, which matches the update dimension.

So the mismatch is that we expect input_dim = 21 (the number of amino acids), but we have 67. I have copied the protein transforms and graph construction verbatim from the tutorial:

    truncate_transform = transforms.TruncateProtein(max_length=350, random=False)
    protein_view_transform = transforms.ProteinView(view="residue")
    transform = transforms.Compose([truncate_transform, protein_view_transform])
    ....
    graphs = data.Protein.pack(proteins)
    graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")
    graphs = graph_construction_model(graphs)

What might be the problem here?

Oxer11 commented 8 months ago

Hi, sorry for the confusion. Could you please check again whether you pass the residue_feature instead of atom_feature into the model? If this works correctly, I think the shapes of the input and update tensors should be [350, 21] and [2450, 21]. I can't figure out where the number 67 in [2450, 67] comes from. I guess that you probably need to use proteins = [transform(protein) for protein in proteins] before the data.Protein.pack operation? Or you need to ensure proteins are sampled from a customized data.ProteinDataset, which will feed the protein into transform when sampling it.
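
The expected shapes mentioned here follow directly from the final `update.view(graph.num_node, self.num_relation * self.input_dim)` call in the traceback. A minimal sketch of that bookkeeping in plain Python is given below; `expected_shapes` and `view_is_valid` are hypothetical helpers for illustration, not torchdrug functions, and the numbers are the ones reported in this thread.

```python
def expected_shapes(num_node, num_relation, input_dim):
    """Shapes that the final update.view(...) in the traceback relies on."""
    return {
        "input": (num_node, input_dim),
        "update": (num_node * num_relation, input_dim),
        "viewed": (num_node, num_relation * input_dim),
    }


def view_is_valid(update_shape, num_node, num_relation, input_dim):
    """A view only succeeds if the total number of elements is unchanged."""
    rows, cols = update_shape
    return rows * cols == num_node * num_relation * input_dim


shapes = expected_shapes(num_node=350, num_relation=7, input_dim=21)
print(shapes)
# {'input': (350, 21), 'update': (2450, 21), 'viewed': (350, 147)}

print(view_is_valid((2450, 21), 350, 7, 21))  # True: the shape expected above
print(view_is_valid((2450, 67), 350, 7, 21))  # False: the shape actually reported
```

Checking the second dimension of the feature tensor against input_dim before calling the model would surface the mismatch earlier than the failing view.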

gtamer2 commented 8 months ago

You are correct. It turns out that, in the process of fixing something else, I was passing my transform object to the self.load_pdbs function of my dataset, which inherits from torchdrug.data.ProteinDataset, but was not actually applying the transform in the __getitem__ method.

I added that back in and got this working.

Thank you!
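
The failure mode described in this last exchange (a transform that is stored on the dataset but never applied when an item is fetched) can be illustrated with a small library-free sketch. Everything here is a stand-in: `SetView` only mimics the role of a view-setting transform, `ToyDataset` is a hypothetical class, and the default "atom" view simply mirrors the maintainer's remark that atom features are used when the view is not set. It is not the torchdrug implementation.

```python
class SetView:
    """Stand-in for a view-setting transform: tags the item with a view."""

    def __init__(self, view):
        self.view = view

    def __call__(self, item):
        item = dict(item)  # copy so the stored item is left untouched
        item["view"] = self.view
        return item


class ToyDataset:
    """Hypothetical dataset; only the transform handling matters here."""

    def __init__(self, items, transform=None, apply_transform=True):
        self.items = items
        self.transform = transform
        self.apply_transform = apply_transform  # False reproduces the bug

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        item = {"graph": self.items[index], "view": "atom"}
        # The fix: actually call the stored transform on every fetched item.
        if self.transform is not None and self.apply_transform:
            item = self.transform(item)
        return item


buggy = ToyDataset(["protein_0"], transform=SetView("residue"), apply_transform=False)
fixed = ToyDataset(["protein_0"], transform=SetView("residue"))

print(buggy[0]["view"])  # atom
print(fixed[0]["view"])  # residue
```

In the buggy variant the transform object exists and looks correctly configured, which is why inspecting the transform alone does not reveal the problem; inspecting a fetched item does.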