DS4SD / MolGrapher

MolGrapher: Graph-based Visual Recognition of Chemical Structures
https://arxiv.org/abs/2308.12234
MIT License

Error when loading model #6

Closed (sincelover closed 4 months ago)

sincelover commented 4 months ago

Hi, dear developers! Thank you very much for sharing the code!

I run into the following problem when running the code (bash molgrapher/scripts/annotate/run.sh). With the models gc_gcn_model.ckpt and gc_stereo_model.ckpt, I get the error below. The program only runs with gc_no_stereo_model.ckpt. How can I fix this?

Traceback (most recent call last):
  File "/home/gongjunyu/MolGrapher-main/molgrapher/scripts/annotate/predict_molgrapher.py", line 329, in <module>
    main()
  File "/home/gongjunyu/MolGrapher-main/molgrapher/scripts/annotate/predict_molgrapher.py", line 325, in main
    proceed_batch(args, _batch_images_paths)
  File "/home/gongjunyu/MolGrapher-main/molgrapher/scripts/annotate/predict_molgrapher.py", line 96, in proceed_batch
    model = GraphRecognizer(
            ^^^^^^^^^^^^^^^^
  File "/home/gongjunyu/MolGrapher-main/molgrapher/models/graph_recognizer.py", line 53, in __init__
    self.graph_classifier = GraphClassifier.load_from_checkpoint(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gongjunyu/anaconda3/envs/molgrapher1/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 1561, in load_from_checkpoint
    loaded = _load_from_checkpoint(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gongjunyu/anaconda3/envs/molgrapher1/lib/python3.11/site-packages/pytorch_lightning/core/saving.py", line 89, in _load_from_checkpoint
    model = _load_state(cls, checkpoint, strict=strict, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gongjunyu/anaconda3/envs/molgrapher1/lib/python3.11/site-packages/pytorch_lightning/core/saving.py", line 169, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gongjunyu/anaconda3/envs/molgrapher1/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GraphClassifier:
        Unexpected key(s) in state_dict: "gnn.conv1.bias", "gnn.conv1.lin.weight", "gnn.conv2.bias", "gnn.conv2.lin.weight", "gnn.conv3.bias", "gnn.conv3.lin.weight", "gnn.conv4.bias", "gnn.conv4.lin.weight", "criterion_bonds.weight".
        size mismatch for gnn.mlp_atoms.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
        size mismatch for gnn.mlp_atoms.3.weight: copying a param with shape torch.Size([141, 256]) from checkpoint, the shape in current model is torch.Size([182, 256]).
        size mismatch for gnn.mlp_atoms.3.bias: copying a param with shape torch.Size([141]) from checkpoint, the shape in current model is torch.Size([182]).
        size mismatch for gnn.mlp_bonds.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
        size mismatch for gnn.mlp_bonds.3.weight: copying a param with shape torch.Size([5, 256]) from checkpoint, the shape in current model is torch.Size([6, 256]).
        size mismatch for gnn.mlp_bonds.3.bias: copying a param with shape torch.Size([5]) from checkpoint, the shape in current model is torch.Size([6]).
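The error above reports two kinds of problems: keys that exist in the checkpoint but not in the model, and parameters whose shapes differ. Below is a minimal sketch of that comparison, using plain dictionaries of shapes instead of real tensors. `diff_state_shapes` is a hypothetical helper, not part of MolGrapher or PyTorch. With PyTorch installed, the two mappings could be built from `torch.load(path, map_location="cpu")["state_dict"]` and from `model.state_dict()`.

```python
def diff_state_shapes(checkpoint_shapes, model_shapes):
    """Compare two {parameter_name: shape} mappings.

    Hypothetical helper for illustration. It mirrors the categories that
    load_state_dict reports, but it does not call PyTorch.
    """
    unexpected = sorted(k for k in checkpoint_shapes if k not in model_shapes)
    missing = sorted(k for k in model_shapes if k not in checkpoint_shapes)
    mismatched = {
        name: (checkpoint_shapes[name], model_shapes[name])
        for name in checkpoint_shapes
        if name in model_shapes
        and tuple(checkpoint_shapes[name]) != tuple(model_shapes[name])
    }
    return unexpected, missing, mismatched


# A subset of the entries from the error message above.
# None marks a shape that the error message does not show.
checkpoint = {
    "gnn.conv1.bias": None,
    "criterion_bonds.weight": None,
    "gnn.mlp_atoms.0.weight": (256, 256),
    "gnn.mlp_atoms.3.bias": (141,),
    "gnn.mlp_bonds.3.bias": (5,),
}
model = {
    "gnn.mlp_atoms.0.weight": (256, 2048),
    "gnn.mlp_atoms.3.bias": (182,),
    "gnn.mlp_bonds.3.bias": (6,),
}

unexpected, missing, mismatched = diff_state_shapes(checkpoint, model)
for name, (ckpt_shape, model_shape) in mismatched.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

Read this way, the output sizes 141 and 5 in the checkpoint against 182 and 6 in the model suggest that the checkpoint was trained with different class counts than the ones currently configured.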
lucas-morin commented 4 months ago

Hello!

Thank you, and sorry for the lack of information in the instructions. To use gc_stereo_model, you need to:

1- Change the selected model by commenting or uncommenting the following lines: https://github.com/DS4SD/MolGrapher/blob/dfc83eb6c07fb9d1d29d3a7fde6089add924760a/molgrapher/models/graph_recognizer.py#L46

2- In the configuration file, change nb_bonds_classes from 6 to 8: https://github.com/DS4SD/MolGrapher/blob/dfc83eb6c07fb9d1d29d3a7fde6089add924760a/data/config_dataset_graph_2.json#L9

3- In run_predict.sh, replace the run argument --no-assign-stereo with --assign-stereo: https://github.com/DS4SD/MolGrapher/blob/dfc83eb6c07fb9d1d29d3a7fde6089add924760a/molgrapher/scripts/annotate/run_predict.sh#L12

To test it, I added an image of a molecule with stereochemistry to the repository; it should be predicted correctly.
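The configuration change in step 2 can also be scripted. The sketch below assumes that the configuration file is a flat JSON object with `nb_bonds_classes` as a top-level key. `set_config_values` is a hypothetical helper, and the demonstration runs on a temporary stand-in file rather than the real data/config_dataset_graph_2.json, whose other keys are not reproduced here.

```python
import json
import tempfile
from pathlib import Path


def set_config_values(config_path, **overrides):
    """Rewrite selected keys of a JSON config file in place.

    Hypothetical helper; it refuses to add keys that are not already
    present, so that a typo in a key name is caught early.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())
    for key, value in overrides.items():
        if key not in config:
            raise KeyError(f"{key!r} is not a key of {path.name}")
        config[key] = value
    path.write_text(json.dumps(config, indent=4) + "\n")
    return config


# Demonstration on a stand-in file containing only the two keys named in this thread.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "config_dataset_graph_2.json"
    demo_path.write_text(json.dumps({"nb_atoms_classes": 182, "nb_bonds_classes": 6}))
    updated = set_config_values(demo_path, nb_bonds_classes=8)
    reloaded = json.loads(demo_path.read_text())
```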

To use gc_gcn_model, you need to:

1- Change the selected model by commenting or uncommenting the following lines: https://github.com/DS4SD/MolGrapher/blob/49fc3e95cf25a3cf60223947074b67dbe0d0aee2/molgrapher/models/graph_recognizer.py#L47

2- In the configuration file, modify nb_bonds_classes from 6 to 5, and nb_atoms_classes from 182 to 141: https://github.com/DS4SD/MolGrapher/blob/dfc83eb6c07fb9d1d29d3a7fde6089add924760a/data/config_dataset_graph_2.json#L9

3- Change self.gcn_on = False to self.gcn_on = True, here: https://github.com/DS4SD/MolGrapher/blob/49fc3e95cf25a3cf60223947074b67dbe0d0aee2/molgrapher/models/graph_classifier.py#L85

(run_predict.sh should be run with the default argument --no-assign-stereo.)
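The settings described above for the three checkpoints can be summarized in one lookup table. This is a sketch, not code from the repository: `CheckpointSettings` and `settings_for` are hypothetical names. The values for gc_no_stereo_model are the defaults implied by the instructions ("from 6", "from 182", `self.gcn_on = False`, `--no-assign-stereo`). For gc_stereo_model, an unchanged `nb_atoms_classes` of 182 and the flag spelling `--assign-stereo` are assumptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CheckpointSettings:
    """Settings that must match the checkpoint being loaded (illustrative)."""
    nb_atoms_classes: int
    nb_bonds_classes: int
    gcn_on: bool
    stereo_flag: str


SETTINGS = {
    # Defaults implied by the instructions in this thread.
    "gc_no_stereo_model": CheckpointSettings(182, 6, False, "--no-assign-stereo"),
    # nb_atoms_classes and the flag spelling are assumptions, see the text above.
    "gc_stereo_model": CheckpointSettings(182, 8, False, "--assign-stereo"),
    "gc_gcn_model": CheckpointSettings(141, 5, True, "--no-assign-stereo"),
}


def settings_for(checkpoint_name):
    """Return the settings for a checkpoint name, or raise ValueError."""
    try:
        return SETTINGS[checkpoint_name]
    except KeyError:
        raise ValueError(f"unknown checkpoint: {checkpoint_name!r}") from None


gcn = settings_for("gc_gcn_model")
print(gcn)
```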

I hope this can help! Best,

Lucas

sincelover commented 4 months ago

Thank you so much for such a detailed response!

Following your instructions, I solved the problem with gc_stereo_model. But when I use gc_gcn_model, the following error still occurs. I wonder whether other parameters still need to be adjusted, or whether I made a mistake somewhere.

Thank you very much.

RuntimeError: Error(s) in loading state_dict for GraphClassifier:
        Unexpected key(s) in state_dict: "criterion_bonds.weight".
        size mismatch for gnn.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
        size mismatch for gnn.conv1.lin.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
        size mismatch for gnn.conv2.lin.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
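To see at a glance which dimensions disagree, the "size mismatch" lines can be parsed into shape pairs. The sketch below uses only the standard library and the message format shown above. `parse_mismatches` is a hypothetical helper, and the regular expression is written for this message format only.

```python
import re

# Matches one "size mismatch" line of the load_state_dict error message.
MISMATCH = re.compile(
    r"size mismatch for (?P<name>[\w.]+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]*)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]*)\]\)"
)


def parse_shape(text):
    """Turn '512, 2048' into (512, 2048)."""
    return tuple(int(part) for part in text.split(",") if part.strip())


def parse_mismatches(message):
    """Return {parameter_name: (checkpoint_shape, model_shape)}."""
    return {
        m["name"]: (parse_shape(m["ckpt"]), parse_shape(m["model"]))
        for m in MISMATCH.finditer(message)
    }


error_text = """
size mismatch for gnn.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for gnn.conv1.lin.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
size mismatch for gnn.conv2.lin.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
"""

mismatches = parse_mismatches(error_text)
for name, (ckpt_shape, model_shape) in mismatches.items():
    print(f"{name}: checkpoint {ckpt_shape}, model {model_shape}")
```

Here the checkpoint's first convolution has width 512 while the local model builds it with width 256, which suggests that the local model definition differs from the one used to train the checkpoint.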
lucas-morin commented 4 months ago

I forgot to mention it, but did you pull the latest changes from the repository?

Best, Lucas

sincelover commented 4 months ago

Thanks again. I fetched the latest version of graph_classifier.py, and now it works fine!

lucas-morin commented 4 months ago

You're welcome!

sincelover commented 4 months ago

Finally, can I ask under what conditions each of these three models was trained?

lucas-morin commented 4 months ago

Yes! The models are trained on synthetic images generated using MolDepictor. (The dataset is also available on Hugging Face.) We used 3 NVIDIA A100 GPUs and the Adam optimizer with a learning rate of 1e-4, decayed after 5000 iterations by a factor of 0.8. The training parameters can be set here: https://github.com/DS4SD/MolGrapher/blob/49fc3e95cf25a3cf60223947074b67dbe0d0aee2/data/config_training_graph_2.json#L1

Does this answer your question?
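As an illustration of the schedule described above, here is a minimal sketch of a step decay. It assumes that the factor of 0.8 is applied again every 5000 iterations; the description could also be read as a single decay, so this repetition is an assumption. `learning_rate` is a hypothetical function, not the repository's training code.

```python
def learning_rate(iteration, base_lr=1e-4, decay_every=5000, factor=0.8):
    """Step decay: multiply the rate by `factor` once per `decay_every` iterations.

    Assumption: the decay repeats every `decay_every` iterations.
    """
    if iteration < 0:
        raise ValueError("iteration must be non-negative")
    return base_lr * factor ** (iteration // decay_every)


for step in (0, 4999, 5000, 10000):
    print(step, learning_rate(step))
```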