DeepGraphLearning / GearNet

GearNet and Geometric Pretraining Methods for Protein Structure Representation Learning, ICLR'2023 (https://arxiv.org/abs/2203.06125)
MIT License

shape mismatch #24

Closed: pearl-rabbit closed this issue 1 year ago

pearl-rabbit commented 1 year ago

I encountered a shape mismatch issue during runtime.

File "/home/admin/anaconda3/envs/test_env/lib/python3.7/site-packages/torchdrug-0.2.0-py3.7.egg/torchdrug/layers/conv.py", line 813, in message_and_aggregate
    return update.view(graph.num_node, self.num_relation * self.input_dim)
RuntimeError: shape '[975, 472]' is invalid for input of size 312000

protein structure:

print(protein, protein.node_feature.shape)  # PackedProtein(batch_size=1, num_atoms=[51], num_bonds=[975], num_residues=[51])   torch.Size([51, 21])

Oxer11 commented 1 year ago

Hi, could you provide more context about the error, including the command you're running and the dataset and model you're using?

pearl-rabbit commented 1 year ago

Error information:

Traceback (most recent call last):
  File "test230424_define_model.py", line 67, in <module>
    output = gearnet_edge(protein, protein.node_feature.float(), all_loss=None, metric=None)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torchdrug-0.2.0-py3.7.egg/torchdrug/models/gearnet.py", line 99, in forward
    edge_hidden = self.edge_layers[i](line_graph, edge_input)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torchdrug-0.2.0-py3.7.egg/torchdrug/layers/conv.py", line 91, in forward
    update = self.message_and_aggregate(graph, input)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torchdrug-0.2.0-py3.7.egg/torchdrug/layers/conv.py", line 813, in message_and_aggregate
    return update.view(graph.num_node, self.num_relation * self.input_dim)
RuntimeError: shape '[19, 472]' is invalid for input of size 6080

protein: I loaded only one protein, and I provide part of 1A0G.pdb below. The number of residues loaded does not seem to affect the result.

PackedProtein(batch_size=1, num_atoms=[5], num_bonds=[19], num_residues=[5]) torch.Size([5, 3])

ATOM      1  N   GLY A   1      62.683  18.043  31.832  1.00 27.29           N  
ATOM      2  CA  GLY A   1      62.540  19.333  31.113  1.00 26.20           C  
ATOM      3  C   GLY A   1      61.709  20.294  31.930  1.00 26.29           C  
ATOM      4  O   GLY A   1      61.503  20.069  33.122  1.00 26.53           O  
ATOM      5  H1  GLY A   1      63.547  17.860  31.944  1.00 27.29           H  
ATOM      6  H2  GLY A   1      62.287  18.100  32.627  1.00 27.29           H  
ATOM      7  H3  GLY A   1      62.302  17.394  31.357  1.00 27.29           H  
ATOM      8  HA2 GLY A   1      63.415  19.715  30.943  1.00 26.20           H  
ATOM      9  HA3 GLY A   1      62.122  19.186  30.250  1.00 26.20           H  
ATOM     10  N   TYR A   2      61.235  21.352  31.285  1.00 25.57           N  
ATOM     11  CA  TYR A   2      60.403  22.375  31.908  1.00 25.80           C  
ATOM     12  C   TYR A   2      59.040  22.446  31.212  1.00 24.57           C  
ATOM     13  O   TYR A   2      58.920  22.239  29.996  1.00 23.52           O  
ATOM     14  CB  TYR A   2      61.066  23.748  31.808  1.00 27.46           C  
ATOM     15  CG  TYR A   2      62.320  23.894  32.630  1.00 30.74           C  
ATOM     16  CD1 TYR A   2      63.564  23.548  32.104  1.00 31.97           C  
ATOM     17  CD2 TYR A   2      62.265  24.368  33.941  1.00 32.03           C  
ATOM     18  CE1 TYR A   2      64.730  23.662  32.861  1.00 33.56           C  
ATOM     19  CE2 TYR A   2      63.429  24.490  34.713  1.00 34.25           C  
ATOM     20  CZ  TYR A   2      64.659  24.131  34.162  1.00 34.25           C  
ATOM     21  OH  TYR A   2      65.812  24.229  34.910  1.00 36.90           O  
ATOM     22  H   TYR A   2      61.392  21.500  30.452  1.00 25.57           H  
ATOM     23  HA  TYR A   2      60.290  22.135  32.841  1.00 25.80           H  
ATOM     24  HB2 TYR A   2      61.279  23.925  30.878  1.00 27.46           H  
ATOM     25  HB3 TYR A   2      60.429  24.424  32.087  1.00 27.46           H  
ATOM     26  HD1 TYR A   2      63.617  23.235  31.230  1.00 31.97           H  
ATOM     27  HD2 TYR A   2      61.444  24.606  34.308  1.00 32.03           H  
ATOM     28  HE1 TYR A   2      65.551  23.424  32.494  1.00 33.56           H  
ATOM     29  HE2 TYR A   2      63.381  24.808  35.586  1.00 34.25           H  
ATOM     30  HH  TYR A   2      65.626  24.526  35.674  1.00 36.90           H  
ATOM     31  N   THR A   3      58.029  22.784  31.994  1.00 22.76           N  
ATOM     32  CA  THR A   3      56.674  22.908  31.512  1.00 20.04           C  
ATOM     33  C   THR A   3      56.145  24.288  31.854  1.00 20.22           C  
ATOM     34  O   THR A   3      56.566  24.902  32.840  1.00 21.09           O  
ATOM     35  CB  THR A   3      55.813  21.835  32.187  1.00 19.90           C  
ATOM     36  OG1 THR A   3      56.348  20.551  31.868  1.00 18.58           O  
ATOM     37  CG2 THR A   3      54.358  21.891  31.725  1.00 19.23           C  
ATOM     38  H   THR A   3      58.116  22.950  32.833  1.00 22.76           H  
ATOM     39  HA  THR A   3      56.647  22.789  30.550  1.00 20.04           H  
ATOM     40  HB  THR A   3      55.829  21.996  33.143  1.00 19.90           H  
ATOM     41  HG1 THR A   3      57.091  20.455  32.247  1.00 18.58           H  
ATOM     42 HG21 THR A   3      53.849  21.198  32.174  1.00 19.23           H  
ATOM     43 HG22 THR A   3      53.983  22.759  31.941  1.00 19.23           H  
ATOM     44 HG23 THR A   3      54.317  21.751  30.766  1.00 19.23           H  
ATOM     45  N   LEU A   4      55.313  24.822  30.972  1.00 18.82           N  
ATOM     46  CA  LEU A   4      54.661  26.099  31.184  1.00 18.80           C  
ATOM     47  C   LEU A   4      53.361  25.744  31.916  1.00 18.45           C  
ATOM     48  O   LEU A   4      52.540  24.988  31.412  1.00 18.08           O  
ATOM     49  CB  LEU A   4      54.363  26.774  29.843  1.00 19.09           C  
ATOM     50  CG  LEU A   4      53.376  27.937  29.779  1.00 19.90           C  
ATOM     51  CD1 LEU A   4      53.899  29.172  30.510  1.00 19.94           C  
ATOM     52  CD2 LEU A   4      53.136  28.257  28.336  1.00 20.91           C  
ATOM     53  H   LEU A   4      55.110  24.447  30.225  1.00 18.82           H  
ATOM     54  HA  LEU A   4      55.209  26.720  31.688  1.00 18.80           H  
ATOM     55  HB2 LEU A   4      55.207  27.091  29.485  1.00 19.09           H  
ATOM     56  HB3 LEU A   4      54.040  26.087  29.239  1.00 19.09           H  
ATOM     57  HG  LEU A   4      52.552  27.678  30.220  1.00 19.90           H  
ATOM     58 HD11 LEU A   4      53.246  29.886  30.447  1.00 19.94           H  
ATOM     59 HD12 LEU A   4      54.052  28.956  31.443  1.00 19.94           H  
ATOM     60 HD13 LEU A   4      54.732  29.459  30.105  1.00 19.94           H  
ATOM     61 HD21 LEU A   4      52.510  28.995  28.268  1.00 20.91           H  
ATOM     62 HD22 LEU A   4      53.974  28.504  27.915  1.00 20.91           H  
ATOM     63 HD23 LEU A   4      52.768  27.479  27.889  1.00 20.91           H  
ATOM     64  N   TRP A   5      53.244  26.216  33.147  1.00 18.81           N  
ATOM     65  CA  TRP A   5      52.090  25.958  33.974  1.00 19.95           C  
ATOM     66  C   TRP A   5      51.552  27.327  34.334  1.00 20.86           C  
ATOM     67  O   TRP A   5      52.060  27.978  35.250  1.00 19.16           O  
ATOM     68  CB  TRP A   5      52.518  25.197  35.224  1.00 21.36           C  
ATOM     69  CG  TRP A   5      51.379  24.766  36.083  1.00 23.63           C  
ATOM     70  CD1 TRP A   5      50.043  24.813  35.774  1.00 24.43           C  
ATOM     71  CD2 TRP A   5      51.468  24.189  37.391  1.00 25.64           C  
ATOM     72  NE1 TRP A   5      49.305  24.293  36.805  1.00 25.53           N  
ATOM     73  CE2 TRP A   5      50.148  23.904  37.810  1.00 25.61           C  
ATOM     74  CE3 TRP A   5      52.535  23.882  38.250  1.00 25.32           C  
ATOM     75  CZ2 TRP A   5      49.866  23.330  39.050  1.00 27.27           C  
ATOM     76  CZ3 TRP A   5      52.254  23.312  39.483  1.00 27.22           C  
ATOM     77  CH2 TRP A   5      50.928  23.040  39.872  1.00 28.11           C  
ATOM     78  H   TRP A   5      53.844  26.702  33.527  1.00 18.81           H  
ATOM     79  HA  TRP A   5      51.420  25.418  33.526  1.00 19.95           H  
ATOM     80  HB2 TRP A   5      53.026  24.415  34.958  1.00 21.36           H  
ATOM     81  HB3 TRP A   5      53.113  25.758  35.747  1.00 21.36           H  
ATOM     82  HD1 TRP A   5      49.690  25.148  34.982  1.00 24.43           H  
ATOM     83  HE1 TRP A   5      48.448  24.222  36.818  1.00 25.53           H  
ATOM     84  HE3 TRP A   5      53.413  24.057  37.997  1.00 25.32           H  
ATOM     85  HZ2 TRP A   5      48.992  23.150  39.311  1.00 27.27           H  
ATOM     86  HZ3 TRP A   5      52.952  23.106  40.062  1.00 27.22           H  
ATOM     87  HH2 TRP A   5      50.768  22.656  40.704  1.00 28.11           H
TER
END

model definition:

# protein
protein = data.Protein.from_pdb(pdb_file, atom_feature="position", bond_feature="length", residue_feature="symbol")
_protein = data.Protein.pack([protein])
protein = graph_construction_model(_protein)

# model
gearnet_edge = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512],
                              num_relation=7, edge_input_dim=59, num_angle_bin=8,
                              batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")
pthfile = 'models/angle_gearnet_edge.pth'
net = torch.load(pthfile)
gearnet_edge.load_state_dict(net)

# written according to the documentation
truncate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])
item = {"graph": protein}
if transform:
    item = transform(item)
protein = item['graph']

# output
with torch.no_grad():
    output = gearnet_edge(protein, protein.node_feature.float(), all_loss=None, metric=None)

Oxer11 commented 1 year ago

It seems that the edge feature in your protein has shape 19*40, while the model is defined with edge_input_dim=59. Have you tried the following graph_construction_model?

from torchdrug import layers
from torchdrug.layers import geometry

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()],
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

Note that you need to set edge_feature to "gearnet" in graph_construction_model to get an edge feature of shape 19*59.
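
For reference, the 40 and 59 can be read straight off the error messages. The sketch below is plain arithmetic and does not call TorchDrug. It assumes that the tensor passed to `view` holds `num_node * num_relation * feature_dim` elements, and that the relation count in the edge layers equals `num_angle_bin=8` (which matches 472 = 8 * 59). The helper name `inferred_feature_dim` is hypothetical.

```python
import re

def inferred_feature_dim(message: str, num_relation: int) -> int:
    """Recover the per-edge feature width implied by a view() shape error.

    Assumption: the flattened tensor holds
    num_node * num_relation * feature_dim elements.
    """
    match = re.search(
        r"shape '\[(\d+), (\d+)\]' is invalid for input of size (\d+)", message
    )
    if match is None:
        raise ValueError("unrecognised error message")
    num_node, _expected_cols, actual_size = map(int, match.groups())
    feature_dim, remainder = divmod(actual_size, num_node * num_relation)
    if remainder:
        raise ValueError("size is not a multiple of num_node * num_relation")
    return feature_dim

NUM_ANGLE_BIN = 8     # from the model definition in this thread
EDGE_INPUT_DIM = 59   # from the model definition in this thread

# Columns the model asks for: 8 * 59 = 472, as in the traceback.
expected_cols = NUM_ANGLE_BIN * EDGE_INPUT_DIM

# Feature width the graph actually supplied, for both reported errors.
small = inferred_feature_dim(
    "RuntimeError: shape '[19, 472]' is invalid for input of size 6080",
    NUM_ANGLE_BIN,
)
large = inferred_feature_dim(
    "RuntimeError: shape '[975, 472]' is invalid for input of size 312000",
    NUM_ANGLE_BIN,
)
print(expected_cols, small, large)
```

Both errors imply the same supplied width, which is why the number of residues loaded did not change the outcome: the mismatch is in the edge feature dimension, not in the graph size.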

pearl-rabbit commented 1 year ago

Thank you for patiently answering my question. It has been resolved.