DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

Shape mismatch for node attribute `atom_feature` in GCPN #117

Closed by jannisborn 2 years ago

jannisborn commented 2 years ago

Thanks for releasing torchdrug 0.1.3!

I updated to the new version and see improved behavior in many places. Unfortunately, some functionality that was stable in 0.1.2 now fails.

For example, when performing inference with a trained model, torchdrug/data/graph.py fails at line 159:

self = PackedMolecule(batch_size=32, num_atoms=[2, 2, 2, ..., 2, 2, 2], num_bonds=[2, 2, 2, ..., 2, 2, 2]), key = 'atom_feature'
value = tensor([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

    def _check_attribute(self, key, value):
        for type in self._meta_contexts:
            if "reference" in type:
                if value.dtype != torch.long:
                    raise TypeError("Tensors used as reference must be long tensors")
            if type == "node":
                if len(value) != self.num_node:
                    raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" %
>                                    (key, self.num_node, value.shape))
E                   ValueError: Expect node attribute `atom_feature` to have shape (64, *), but found torch.Size([32, 18])
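
To make the numbers in the message concrete: the batch holds 32 graphs with 2 atoms each, so the check expects 64 rows, while `atom_feature` only has 32. The snippet below is a minimal pure-Python sketch of that length check, not torchdrug code. `check_node_attribute` is a hypothetical helper, and nested lists stand in for tensors.

```python
# Hypothetical stand-in for the node-attribute length check shown in the trace.
# This is NOT torchdrug code; nested lists replace tensors for illustration.
def check_node_attribute(key, value, num_atoms):
    # Total node count of a packed batch is the sum of per-graph atom counts.
    num_node = sum(num_atoms)
    if len(value) != num_node:
        raise ValueError(
            "Expect node attribute `%s` to have shape (%d, *), but found (%d, %d)"
            % (key, num_node, len(value), len(value[0]))
        )


# Values taken from the report: batch_size=32, two atoms per graph,
# and an 18-dimensional feature row for only 32 nodes.
num_atoms = [2] * 32
atom_feature = [[0] * 18 for _ in range(32)]

try:
    check_node_attribute("atom_feature", atom_feature, num_atoms)
    message = None
except ValueError as err:
    message = str(err)
    print(message)
```

The feature matrix has exactly one row per graph rather than one per atom, which would be consistent with features that were computed before each graph gained its second atom. That reading is an assumption, not something the trace proves.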

The error occurs when following the molecule generation tutorial here. After training the model as described in the tutorial, inference works. However, after loading a saved checkpoint as described in the next step of the tutorial (solver.load("path_to_dump/graphgeneration/gcpn_zinc250k_1epoch.pkl")), sample generation raises the above error.

Here's the full trace in torchdrug:

../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/tasks/generation.py:1353: in generate
    new_graph = self._apply_action(graph, off_policy, max_resample, verbose=1)
../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/tasks/generation.py:1283: in _apply_action
    meta_dict=meta_dict, **data_dict)
../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/data/molecule.py:610: in __init__
    offsets=offsets, atom_type=atom_type, bond_type=bond_type, **kwargs)
../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/data/graph.py:1101: in __init__
    num_relation=num_relation, **kwargs)
../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/data/molecule.py:73: in __init__
    self.atom_feature = torch.as_tensor(atom_feature, device=self.device)
../../miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/data/graph.py:159: in __setattr__
    self._check_attribute(key, value)

The error never occurs in the first iteration, only in the second. I'm not sure what's going wrong, but the error occurs consistently in version 0.1.3, irrespective of whether the model was trained in 0.1.2 or 0.1.3.

Could you please advise how to load a trained model for inference in torchdrug 0.1.3? Thanks

jannisborn commented 2 years ago

Hi @KiddoZhu, a gentle nudge on this: please keep in mind that your tutorials are currently failing because of this bug.

KiddoZhu commented 2 years ago

Hi! This is the same problem as the shape mismatch for edge features in retrosynthesis: both are generative models that change the structure of the molecules.

I just fixed it in 9fac912.
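
For readers hitting the same class of error, here is a rough sketch of why a structure-changing step can leave per-node attributes stale, and one possible way to guard against it. It does not reflect what 9fac912 actually changes. `drop_stale_node_attributes` is a hypothetical helper, plain lists stand in for tensors, and the starting point of one atom per graph is an assumption.

```python
# Hypothetical illustration of stale per-node attributes after a structural edit.
# Not torchdrug code and not the actual fix; plain lists stand in for tensors.
def drop_stale_node_attributes(data_dict, num_atoms):
    """Keep only attributes whose first dimension matches the new node count."""
    num_node = sum(num_atoms)
    kept, dropped = {}, []
    for key, value in data_dict.items():
        if len(value) == num_node:
            kept[key] = value
        else:
            dropped.append(key)  # stale: sized for the graph before the edit
    return kept, dropped


# Assumption: generation starts from 32 single-atom graphs.
old_num_atoms = [1] * 32
data_dict = {"atom_feature": [[0] * 18 for _ in range(sum(old_num_atoms))]}

# The action adds one atom to every graph, so the batch now has 64 nodes.
new_num_atoms = [n + 1 for n in old_num_atoms]
data_dict["atom_type"] = [6] * sum(new_num_atoms)  # rebuilt for the new graph

kept, dropped = drop_stale_node_attributes(data_dict, new_num_atoms)
print(sorted(kept), dropped)
```

Dropping a stale attribute is only one option. Recomputing it for the new structure is another, and which one is appropriate depends on whether downstream code still needs the feature.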

jannisborn commented 2 years ago

Thanks @KiddoZhu for the bugfix; I'm closing this issue. While testing the current tip of master, I found another bug in the property optimization and provided a hotfix in a separate PR, #125. Please have a look!