rampasek / GraphGPS

Recipe for a General, Powerful, Scalable Graph Transformer
MIT License

mismatch shape on inference #35

Closed almog2065 closed 1 year ago

almog2065 commented 1 year ago

Hi, thanks for the great project. While trying to run inference after training the small model on the PCQ dataset, I got a shape-mismatch error. I ran all the necessary commands from the README file, and the error showed up after the inference command. I've attached a screenshot of the error.

Precomputing Positional Encoding statistics: ['RWSE'] for all graphs...
  ...estimated to be undirected: True
 77%|██████████████████████████ | 282370/368014 [05:18<01:37, 881.32it/s]
100%|██████████████████████████████████| 368014/368014 [06:54<00:00, 888.64it/s]
Done! Took 00:07:01.53
[*] Loading from pretrained model: pretrained/pcqm4m-GPS+RWSE.deep/0/ckpt/148.ckpt
Traceback (most recent call last):
  File "/home//GraphGPS/main.py", line 146, in
    model = init_model_from_pretrained(
  File "/home//GraphGPS/graphgps/finetuning.py", line 142, in init_model_from_pretrained
    model.load_state_dict(model_dict)
  File "/home/almogben/miniconda3/envs/graphgps/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GraphGymModule:
    size mismatch for model.encoder.node_encoder.encoder1.atom_embedding_list.1.weight: copying a param with shape torch.Size([4, 236]) from checkpoint, the shape in current model is torch.Size([5, 236]).

Also, I want to train a model on PCQ and then fine-tune it on ZINC. After I train the model, where is it saved? Do I need to change your code, or does fine-tuning automatically start from the previously trained model?

Thanks!

rampasek commented 1 year ago

Hi Almog,

From what I could find, this is caused by a change in the atom featurisation specs in the OGB package, particularly this commit: https://github.com/snap-stanford/ogb/commit/184e2bfd8386bc71ed32800111b15e2b91b44791, which became part of the OGB v1.3.6 release in April.
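To see which parameters are affected before calling load_state_dict, one option is to compare parameter shapes by name. The following is only a minimal sketch, not code from GraphGPS: the helper name find_shape_mismatches and the "example.unchanged.weight" entry are hypothetical, and the shapes are written as plain tuples so the snippet runs without PyTorch. With PyTorch, such mappings could be built from a state dict with {k: tuple(v.shape) for k, v in state_dict.items()}.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for parameters present
    in both mappings whose shapes differ. Hypothetical helper, not part of
    GraphGPS or PyTorch."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        # Parameters missing from the model are skipped; only shape
        # disagreements are reported.
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


# The first entry uses the name and shapes reported in the traceback above.
# "example.unchanged.weight" is a made-up entry to show a matching parameter.
PARAM = "model.encoder.node_encoder.encoder1.atom_embedding_list.1.weight"
checkpoint_shapes = {PARAM: (4, 236), "example.unchanged.weight": (236, 236)}
model_shapes = {PARAM: (5, 236), "example.unchanged.weight": (236, 236)}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for name, (ckpt_shape, model_shape) in mismatches.items():
    print(f"{name}: checkpoint {ckpt_shape} vs current model {model_shape}")
```

Here the only reported parameter is the embedding whose first dimension grew from 4 to 5, which is consistent with one extra category being added to a categorical atom feature.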

That means that, unfortunately, the checkpoints I published are not directly compatible with the newest OGB package. Try downgrading to OGB v1.3.5 or patching the 'possible_chirality_list' featurisation locally.
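A quick way to check whether the installed OGB version predates the featurisation change is to compare its version string against v1.3.6. This is a hedged sketch using only the standard library: the helper names (parse_version, matches_published_checkpoints, installed_ogb_version) are hypothetical, and the cutoff is taken from the statement above that the change shipped in v1.3.6.

```python
import re
from importlib import metadata

# Assumption taken from this thread: the atom featurisation change
# first shipped in OGB v1.3.6.
FIRST_INCOMPATIBLE = (1, 3, 6)


def parse_version(text):
    """Turn a version string such as '1.3.5' into a tuple of ints,
    stopping at the first non-numeric component."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def matches_published_checkpoints(version_text):
    """True if the version is older than the first release with the change."""
    return parse_version(version_text) < FIRST_INCOMPATIBLE


def installed_ogb_version():
    """Return the installed 'ogb' distribution version, or None if absent."""
    try:
        return metadata.version("ogb")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_ogb_version()
    if version is None:
        print("ogb is not installed in this environment")
    elif matches_published_checkpoints(version):
        print(f"ogb {version}: predates the featurisation change")
    else:
        print(f"ogb {version}: includes the featurisation change")
```

If the check reports a version at or above 1.3.6, pinning the package (for example, installing ogb==1.3.5 with pip) is the downgrade route suggested above.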

You can also have a look at the https://github.com/datamol-io/graphium project, which implements GPS(++) and is actively supported.

I hope this helps! Ladislav