DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

AssertionError in num_relation #82

Open jannisborn opened 2 years ago

jannisborn commented 2 years ago

In the property optimization setting, it can easily happen that an AssertionError is raised in https://github.com/DeepGraphLearning/torchdrug/blob/d187dd85ed38042bc7e76e7a8c6f26d0f931cd3b/torchdrug/layers/conv.py#L422

I investigated and found that graph.num_relation was 3 whereas self.num_relation was 4. graph.num_relation is lowered by this line: https://github.com/DeepGraphLearning/torchdrug/blob/d187dd85ed38042bc7e76e7a8c6f26d0f931cd3b/torchdrug/tasks/generation.py#L1345

where kekulize is hard-coded to True. Consequently, aromatic bonds are removed from the bond count. I do not want to kekulize my molecules, and I launched training with that specification, but the package does not let the user control this hard-coded value.
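The mismatch can be sketched in a few lines. BOND_RELATIONS and num_relations below are hypothetical names for illustration, not torchdrug's actual encoding, but the idea is the same: kekulization rewrites aromatic bonds as alternating single/double bonds, so the AROMATIC relation disappears and the relation count drops from 4 to 3.

```python
# Hypothetical illustration of how kekulization shrinks the relation set.
BOND_RELATIONS = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}

def num_relations(kekulize: bool) -> int:
    # With kekulize=True, aromatic bonds are rewritten as alternating
    # single/double bonds, so the AROMATIC relation type is never emitted.
    relations = dict(BOND_RELATIONS)
    if kekulize:
        relations.pop("AROMATIC")
    return len(relations)
```

This matches the observed failure: a model built for 4 relations (kekulize=False at dataset construction) receives graphs with only 3 relations from the hard-coded kekulize=True in generation.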

Here's the full error trace

Traceback (most recent call last):
  File "/Users/jab/gt4sd/gt4sd-core/src/gt4sd/training_pipelines/torchdrug/gcpn/core.py", line 146, in train
    solver.train(num_epoch=epochs)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/core/engine.py", line 143, in train
    loss, metric = model(batch)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/tasks/generation.py", line 872, in forward
    _loss, _metric = self.reinforce_forward(batch)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/tasks/generation.py", line 1028, in reinforce_forward
    verbose=1,
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/tasks/generation.py", line 1807, in generate
    new_graph = self._apply_action(graph, off_policy, max_resample, verbose=1)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/tasks/generation.py", line 1598, in _apply_action
    ) = self._sample_action(graph, off_policy)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/tasks/generation.py", line 1316, in _sample_action
    output = model(graph, graph.node_feature.float())
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/models/gcn.py", line 153, in forward
    hidden = layer(graph, layer_input)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/layers/conv.py", line 91, in forward
    update = self.message_and_aggregate(graph, input)
  File "/Users/jab/miniconda3/envs/gt4sd/lib/python3.7/site-packages/torchdrug/layers/conv.py", line 423, in message_and_aggregate
    assert graph.num_relation == self.num_relation
AssertionError
KiddoZhu commented 2 years ago

This looks like a discrepancy between the generation part and the RGCN model. We will fix it.

If I recall correctly, the generative models don't perform well if you disable kekulize. It is not very easy to predict a ring correctly using aromatic bonds. Could you confirm that? @shichence

jannisborn commented 2 years ago

Thanks @KiddoZhu! Hm, I see that generation might be easier if the molecules are kekulized, but I still feel this should be a user decision.

Especially since the dataset constructor allows setting this option. The bare minimum would be to raise an error stating that property optimization does not work without kekulization. I had to dig for a while to find the cause of this error.
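A guard along these lines could surface the problem immediately (check_relation_compat is a hypothetical helper, not part of torchdrug):

```python
def check_relation_compat(graph_num_relation: int, model_num_relation: int) -> None:
    # Hypothetical guard that replaces the bare assert with an actionable message.
    if graph_num_relation != model_num_relation:
        raise ValueError(
            f"graph.num_relation ({graph_num_relation}) does not match the "
            f"model's num_relation ({model_num_relation}); this typically "
            "happens when the dataset was built with kekulize=False but the "
            "generation task kekulizes molecules internally."
        )
```

Compared with the current `assert graph.num_relation == self.num_relation`, this points the user at the kekulization mismatch instead of a bare AssertionError.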

At the same time, it seems necessary that in the dataset constructor, node_features is set to symbol and not to default. I'm not sure why, but I got some shape mismatches when I changed it to default.

KiddoZhu commented 2 years ago

All autoregressive generative models take symbol as node features, because other features may not be well defined for partial molecules during generation. As for kekulization, if I recall correctly, the original implementations of both GCPN and GraphAF use it, and we follow that as the default.
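The point about partial molecules can be illustrated with a toy sketch (hypothetical helpers, not torchdrug code): a symbol feature depends only on the atom itself, so it is identical at every generation step, while a structure-dependent feature such as degree changes every time an edge is added.

```python
# Toy sketch of why per-atom "symbol" features stay valid for partial
# molecules during autoregressive generation, while structural features
# like degree do not.
VOCAB = ["C", "N", "O"]

def symbol_feature(symbol: str) -> list:
    # One-hot over the atom vocabulary; depends only on the atom itself.
    return [int(symbol == s) for s in VOCAB]

def degree_feature(adjacency: dict, atom: int) -> int:
    # Depends on the partial graph, so it changes as edges are generated.
    return len(adjacency.get(atom, []))

partial = {0: [1]}       # step t: atom 0 bonded to atom 1
grown = {0: [1, 2]}      # step t+1: a new bond to atom 2 was added
```

Here symbol_feature("C") is the same at both steps, but degree_feature gives 1 for `partial` and 2 for `grown`, which is why structure-dependent default features can produce shape or semantics mismatches mid-generation.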

We will try to modify the interface so that users don't need to debug such details.