Layne-Huang / PMDM

mat1 and mat2 shapes cannot be multiplied (389x10 and 8x128) #27

Closed chefbaker4 closed 2 months ago

chefbaker4 commented 3 months ago

Hi,

I was wondering why I am receiving this matrix multiplication error. I downloaded the code and followed the process step by step. I am trying to train the model from scratch, and after processing the CrossDocked data, I run this command:

python -u train.py --config configs/crossdock_epoch.yml

This is throwing an error on what seems like the very first iteration.

Here is the whole stack trace, if that is helpful:

    Traceback (most recent call last):
      File "/home/user/PMDM/train.py", line 416, in <module>
        avg_val_loss = validate(it)
      File "/home/user/PMDM/train.py", line 186, in validate
        loss = model(
      File "/home/user/miniconda3/envs/pmdm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/user/miniconda3/envs/pmdm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/user/PMDM/models/epsnet/MDM_pocket_coor_shared.py", line 635, in forward
        net_out = self.net(
      File "/home/user/PMDM/models/epsnet/MDM_pocket_coor_shared.py", line 374, in net
        ligand_atom_feature = self.ligand_encoder(
      File "/home/user/miniconda3/envs/pmdm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/user/miniconda3/envs/pmdm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/user/PMDM/models/encoders/schnet.py", line 95, in forward
        h = self.emblin(node_attr)
      File "/home/user/miniconda3/envs/pmdm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/user/miniconda3/envs/pmdm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/user/miniconda3/envs/pmdm/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
        return F.linear(input, self.weight, self.bias)
    RuntimeError: mat1 and mat2 shapes cannot be multiplied (389x10 and 8x128)

Layne-Huang commented 2 months ago

Thank you very much for your interest. Please check the dimension of the atom features. You could try setting "atom_type" to 8.
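
For readers checking the dimensions, here is a minimal sketch of the shape rule behind the error. In the message, mat1 (389x10) is the atom feature matrix, 389 atoms with 10 features each, and mat2 (8x128) is the transposed weight of a linear layer that expects 8 input features. The helper name `matmul_shape` is hypothetical and is not part of PMDM or PyTorch; it only mimics the check in plain Python.

```python
def matmul_shape(mat1_shape, mat2_shape):
    """Return the shape of mat1 @ mat2, or raise if the inner dimensions differ.

    Hypothetical helper for illustration only; it mirrors the wording of the
    PyTorch error but is not a PyTorch function.
    """
    (rows, inner1), (inner2, cols) = mat1_shape, mat2_shape
    if inner1 != inner2:
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{inner1} and {inner2}x{cols})"
        )
    return (rows, cols)


# Features with 10 columns against a layer built for 10 inputs: shapes agree.
print(matmul_shape((389, 10), (10, 128)))  # (389, 128)

# Features with 10 columns against a layer built for 8 inputs: the reported error.
try:
    matmul_shape((389, 10), (8, 128))
except ValueError as err:
    print(err)
```

Either the feature width or the layer's input size has to change so that the two inner numbers match.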

chefbaker4 commented 2 months ago

Thank you for your help, I was able to get the model to train. For anyone who encounters a similar error, this is the solution I found:

  1. atom_type must be set to 10 in the crossdock_epoch.yml file
  2. The MDM_pocket_coor_shared.py file should be updated as follows on lines 90 and 91:

    self.atom_type_input_dim = config.atom_type if 'atom_type' in config else 8
    self.atom_out_dim = config.atom_type if 'atom_type' in config else 8
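
To show how the two lines above behave, here is a small sketch of the fallback logic in plain Python. The `Config` class is a hypothetical stand-in for the project's config object (assumed to support both attribute access and `in` checks), and `atom_dims` is a made-up helper name; neither comes from the PMDM code.

```python
class Config(dict):
    """Hypothetical stand-in for the config object: a dict with attribute access."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc


def atom_dims(config, default=8):
    """Return (atom_type_input_dim, atom_out_dim) using the same fallback rule."""
    atom_type = config.atom_type if 'atom_type' in config else default
    return atom_type, atom_type


# atom_type set in the YAML file: both dimensions follow it.
print(atom_dims(Config(atom_type=10)))  # (10, 10)

# atom_type missing from the YAML file: both dimensions fall back to 8.
print(atom_dims(Config()))  # (8, 8)
```

With this pattern, the value in crossdock_epoch.yml decides both dimensions, and 8 is used only when the key is absent.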