DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

Results of G2G don't match the article #131

Closed: DimGorr closed this issue 1 year ago

DimGorr commented 1 year ago

I trained the G2G model with center identification for 50 epochs and synthon completion for 10 epochs, using `task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=3, num_synthon_beam=10, max_prediction=20)`. All other parameters are as in the tutorial, except that I added features: `feature=("graph", "atom", "bond")` in the center identification task and `feature=("graph", "atom")` in the synthon completion task. I got the following results:

[screenshot of results]

Training for 100 and 50 epochs for center identification and synthon completion, respectively, I got:

[screenshot of results]

I saw in one of the previous issues that the results were lower than in the article because of too few epochs, and the advice there was to use 100 and 50 epochs, respectively. Do you know what the problem is?
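To make the interaction of `center_topk`, `num_synthon_beam`, and `max_prediction` concrete, here is a minimal plain-Python sketch of a two-stage ranking. It assumes that each final candidate is scored by the sum of its center log-probability and its completion log-probability; this is a generic sketch, not torchdrug's `tasks.Retrosynthesis` implementation, and `combine_two_stage` and its inputs are hypothetical.

```python
import heapq
import math


def combine_two_stage(center_probs, completion_probs, center_topk, num_synthon_beam, max_prediction):
    """Rank (center, completion) pairs by summed log-probability (hypothetical helper).

    center_probs: dict center_id -> probability from stage 1.
    completion_probs: dict center_id -> dict completion_id -> probability from stage 2.
    """
    # Stage 1: keep only the center_topk most likely reaction centers.
    top_centers = heapq.nlargest(center_topk, center_probs, key=center_probs.get)
    candidates = []
    for center in top_centers:
        completions = completion_probs.get(center, {})
        # Stage 2: keep at most num_synthon_beam completions per center.
        best = heapq.nlargest(num_synthon_beam, completions, key=completions.get)
        for comp in best:
            score = math.log(center_probs[center]) + math.log(completions[comp])
            candidates.append((score, center, comp))
    # Final ranking over all surviving pairs, truncated to max_prediction.
    candidates.sort(key=lambda c: c[0], reverse=True)
    return [(center, comp) for _, center, comp in candidates[:max_prediction]]


center_probs = {"c1": 0.6, "c2": 0.3, "c3": 0.1}
completion_probs = {
    "c1": {"a": 0.5, "b": 0.4, "c": 0.1},
    "c2": {"d": 0.9, "e": 0.1},
    "c3": {"f": 1.0},
}
print(combine_two_stage(center_probs, completion_probs,
                        center_topk=2, num_synthon_beam=2, max_prediction=3))
# [('c1', 'a'), ('c2', 'd'), ('c1', 'b')]
```

Under this assumption, `center_topk * num_synthon_beam` bounds how many candidates compete for the `max_prediction` output slots (3 × 10 = 30 for the settings above), so a correct reactant set can only be returned if its center survives stage 1.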

KiddoZhu commented 1 year ago

Do you use the correct features for the datasets, as in the tutorial? If so, I guess the features in the center identification and synthon completion tasks are crucial for performance. Maybe we can ask @shichence, the author of G2Gs?

DimGorr commented 1 year ago

I changed the dataset argument from `node_feature` to `atom_feature`, as you proposed in one of the previous issues. Apart from this, everything is as in the tutorial.

Initially, I had added the feature "atom" to the features of the synthon completion task. However, I have now tried to do everything as in the tutorial, using 100 epochs for center identification, 50 epochs for synthon completion, and `task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=3, num_synthon_beam=10, max_prediction=20)`. I got these results (logged at 22:36:00):

- top-1 accuracy: 0.374486
- top-3 accuracy: 0.561728
- top-5 accuracy: 0.63786
- top-10 accuracy: 0.72428

This is much lower than the benchmark (https://torchdrug.ai/docs/benchmark/retrosynthesis.html). Could it be because of overfitting?

And one more question: in the tutorial you use `reaction_task = tasks.CenterIdentification(reaction_model, feature=("graph", "atom", "bond"))` and `synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))`. As far as I understand, the reaction class is considered unknown here (let me know if I misunderstood something). Yet the tutorial reports:

- top-1 accuracy: 0.47541
- top-3 accuracy: 0.741803
- top-5 accuracy: 0.827869
- top-10 accuracy: 0.879098

This is higher than the reaction-class-unknown results on https://torchdrug.ai/docs/benchmark/retrosynthesis.html. So my question is: do I understand correctly how you define the reaction class?
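Assuming that the `"reaction"` entry in `feature` is what gives the task access to the reaction class, leaving it out (as the tutorial's tuples do) corresponds to the reaction-class-unknown setting. The plain-Python sketch below only illustrates that idea; `center_features` is hypothetical and is not torchdrug's code, and `num_reaction=10` reflects the 10 reaction classes of USPTO-50k.

```python
def center_features(feature, graph_emb, reaction_class=None, num_reaction=10):
    """Concatenate graph-level inputs selected by `feature` (hypothetical helper).

    The reaction class is only visible when "reaction" is listed in `feature`.
    Atom- and bond-level features are omitted here for brevity.
    """
    parts = []
    if "reaction" in feature:
        if reaction_class is None:
            raise ValueError('"reaction" requested but no reaction class given')
        one_hot = [0.0] * num_reaction
        one_hot[reaction_class] = 1.0
        parts.extend(one_hot)
    if "graph" in feature:
        parts.extend(graph_emb)
    return parts


emb = [0.1, 0.2, 0.3]
print(len(center_features(("graph", "atom", "bond"), emb)))  # 3: class unknown
print(len(center_features(("reaction", "graph", "atom", "bond"), emb, reaction_class=4)))  # 13: class known
```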

DimGorr commented 1 year ago

@KiddoZhu With the help of Chence Shi, I got close to the reported results. If you need, I can send you my final hyperparameters. In general, batch sizes and learning rates seem to be extremely important:

- Synthon completion task: lr=1e-4, batch_size=64
- Retrosynthesis prediction task: lr=1e-4, batch_size=16

It would be really nice if you could update the tutorials accordingly, so that people don't struggle with this in the future :)

delpqhiz commented 1 year ago

@DimGorr Can you list your reproduced results and hyperparameters here? I have tried changing the learning rates and batch sizes as you suggested, but I still cannot reproduce the reported results.

DimGorr commented 1 year ago

@delpqhiz Hi! Sure. Here are the results (they are a bit lower than the reported ones, but at least they are close):

- top-1 accuracy: 0.422131
- top-3 accuracy: 0.651639
- top-5 accuracy: 0.729508
- top-10 accuracy: 0.797131

Here are the hyperparameters:

```python
import torch
from torch.utils import data as torch_data
from torchdrug import core, datasets, models, tasks

# --- Center identification ---
reaction_dataset = datasets.USPTO50k("~/molecule-datasets/",
                                     atom_feature="center_identification", as_synthon=False,
                                     with_hydrogen=False, kekulize=True, verbose=1)

torch.manual_seed(1)
reaction_train, reaction_valid, reaction_test = reaction_dataset.split()
reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                             hidden_dims=[256, 256, 256, 256],
                             batch_norm=False, short_cut=True,
                             num_relation=reaction_dataset.num_bond_type,
                             concat_hidden=False)

reaction_task = tasks.CenterIdentification(reaction_model,
                                           feature=("graph", "atom", "bond"), num_mlp_layer=2)

reaction_optimizer = torch.optim.Adam(reaction_task.parameters(), lr=1e-3, weight_decay=0)
reaction_solver = core.Engine(reaction_task, reaction_train, reaction_valid,
                              reaction_test, reaction_optimizer,
                              batch_size=128, gradient_interval=1, log_interval=300)

reaction_solver.train(num_epoch=100)   # perhaps it was 50, not sure
reaction_solver.evaluate("valid")
reaction_solver.save("g2gs_reaction_model1.pth")

# --- Synthon completion ---
synthon_dataset = datasets.USPTO50k("~/molecule-datasets/", as_synthon=True,
                                    atom_feature="synthon_completion", with_hydrogen=False,
                                    kekulize=True, verbose=1)
torch.manual_seed(1)
synthon_train, synthon_valid, synthon_test = synthon_dataset.split()
synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
                            hidden_dims=[512, 512, 512],
                            num_relation=synthon_dataset.num_bond_type,
                            batch_norm=False, short_cut=True, concat_hidden=False)
# feature needs a trailing comma to be a tuple. mlp_act requires a locally modified
# SynthonCompletion (see the note on tanh below); stock torchdrug may reject it.
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",),
                                       num_mlp_layer=4, mlp_act="tanh")
synthon_optimizer = torch.optim.Adam(synthon_task.parameters(), lr=1e-4, weight_decay=0)
synthon_solver = core.Engine(synthon_task, synthon_train, synthon_valid,
                             synthon_test, synthon_optimizer,
                             batch_size=64, gradient_interval=1, log_interval=100)
synthon_solver.train(num_epoch=100)
synthon_solver.evaluate("valid")
synthon_solver.save("g2gs_synthon_model1.pth")

# --- End-to-end retrosynthesis with both trained models ---
task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2,
                            num_synthon_beam=5, max_prediction=10)
# Evaluate on a random 10% subset of the validation set to keep evaluation fast.
lengths = [len(reaction_valid) // 10,
           len(reaction_valid) - len(reaction_valid) // 10]
reaction_valid_small = torch_data.random_split(reaction_valid, lengths)[0]
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, reaction_train, reaction_valid_small, reaction_test,
                     optimizer, batch_size=16)
# Load the weights of both stages; the optimizer state is not needed for evaluation.
solver.load("g2gs_reaction_model1.pth", load_optimizer=False)
solver.load("g2gs_synthon_model1.pth", load_optimizer=False)
solver.evaluate("valid")
```

Note that in the synthon completion task I used the MLP activation 'tanh' on Chence Shi's advice; I slightly modified the original code to allow this. I will look into improving performance later; for now I'm satisfied with what I got. Please let me know if you manage to reproduce the same results :)
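A side note on writing `feature=("graph")`: without a trailing comma this is a plain string, not a one-element tuple. Code that only checks `"graph" in feature` may still behave as intended, but membership becomes a substring test and iteration yields characters. A minimal illustration in plain Python:

```python
# ("graph") is just the string "graph"; the trailing comma makes a 1-tuple.
no_comma = ("graph")
with_comma = ("graph",)

print(type(no_comma).__name__, type(with_comma).__name__)  # str tuple

# Membership tests still "work" on the string, but as substring checks.
print("graph" in no_comma, "atom" in no_comma)  # True False
print("a" in no_comma, "a" in with_comma)       # True False

# Iterating gives characters instead of feature names.
print(list(no_comma))    # ['g', 'r', 'a', 'p', 'h']
print(list(with_comma))  # ['graph']
```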

delpqhiz commented 1 year ago

@DimGorr Thank you for sharing your results and the script! It's very helpful :) I will let you know if I reproduce the same results.

delpqhiz commented 1 year ago

@DimGorr Just to let you know, I have reproduced the results with your script! I also set the MLP activation to 'tanh' for all the MLP layers in the synthon completion module.

With hidden layers [512, 512, 512, 512], a learning rate of 1e-4, and 100 epochs for both modules, I finally got the results below:

- top-1 accuracy: 0.457242
- top-3 accuracy: 0.590426
- top-5 accuracy: 0.629705
- top-10 accuracy: 0.692103

The results are still slightly lower than in the paper, but I think better results are possible by fine-tuning the hyperparameters.

However, I have another question about the results of G2G. I found an important issue about information leakage reported in another repository: https://github.com/uta-smile/RetroXpert/issues/15. Therefore, I also ran G2G on the canonical dataset provided by another method (I changed the dataset loading code). With the same hyperparameters as in the experiment above, I got the results below:

- top-1 accuracy: 0.275445
- top-3 accuracy: 0.437026
- top-5 accuracy: 0.495392
- top-10 accuracy: 0.563178

I'm not sure whether the performance drop here is due to the information leakage issue mentioned in the link above. If it is, I wonder whether the G2Gs authors plan to correct the information leakage in the dataset, as other methods did. Thanks!
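Assuming the leakage discussed in the linked RetroXpert issue comes from the atom ordering of the product molecules, one generic sanity check is to feed a model the same product with its atoms randomly relabeled and see whether the predictions change. A minimal plain-Python sketch of the relabeling step on an edge list follows; `permute_graph` is hypothetical, and a real pipeline would permute the molecule's atom order and its node and edge features together.

```python
import random


def permute_graph(num_nodes, edges, seed=None):
    """Randomly relabel the node indices of a graph given as an edge list.

    Returns (new_edges, perm) where perm[old_index] == new_index.
    """
    rng = random.Random(seed)
    perm = list(range(num_nodes))
    rng.shuffle(perm)
    new_edges = [(perm[u], perm[v]) for u, v in edges]
    return new_edges, perm


edges = [(0, 1), (1, 2), (2, 3)]
new_edges, perm = permute_graph(4, edges, seed=0)

# The graph is unchanged up to relabeling, so a model that ignores node order
# should give the same prediction on both versions.
inverse = {new: old for old, new in enumerate(perm)}
restored = sorted((inverse[u], inverse[v]) for u, v in new_edges)
print(restored == sorted(edges))  # True
```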

DimGorr commented 1 year ago

@delpqhiz Thank you for letting me know! It's weird that the top-k accuracies for k > 2 decreased after your modification, but overall I believe it works.
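Lining up the two reproductions reported earlier in this thread (DimGorr's script versus delpqhiz's variant with four 512-unit layers, tanh in all synthon-completion MLP layers, and 100 epochs for both modules) makes the pattern explicit: top-1 improved, while top-3, top-5, and top-10 dropped.

```python
# Top-k accuracies reported earlier in this thread.
dimgorr = {1: 0.422131, 3: 0.651639, 5: 0.729508, 10: 0.797131}
delpqhiz = {1: 0.457242, 3: 0.590426, 5: 0.629705, 10: 0.692103}

for k in sorted(dimgorr):
    delta = delpqhiz[k] - dimgorr[k]
    print(f"top-{k}: {delta:+.4f}")
# top-1: +0.0351
# top-3: -0.0612
# top-5: -0.0998
# top-10: -0.1050
```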

Concerning your second question, I will try to look into it. Actually, I also have one more issue; perhaps you can help. Either I misunderstand beam search or something is wrong: when I increase num_synthon_beam, the overall results decrease drastically. Have you experienced the same? And am I right that num_synthon_beam is the beam width? If so, this behavior is rather strange.

delpqhiz commented 1 year ago

@DimGorr Regarding your beam search issue, I haven't tried increasing num_synthon_beam yet. But I think a performance drop with a larger num_synthon_beam might be expected. You can check this paper (https://openreview.net/pdf?id=BkE8NjCqYm), which discusses the effect of beam size on performance :)
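Treating `num_synthon_beam` as a beam width, as this thread does, a generic beam search keeps only the `beam_width` highest-scoring partial sequences after each step. The plain-Python sketch below uses a hypothetical toy next-token model, not torchdrug's synthon completion, and shows that a wider beam can return a sequence with a higher model score than greedy decoding. A higher model score is not the same as a correct answer, which is one reason accuracy need not grow with beam width.

```python
import math


def beam_search(step_log_probs, start, length, beam_width):
    """Generic beam search.

    step_log_probs(prefix) -> dict token -> log-probability of the next token.
    Keeps the `beam_width` highest-scoring prefixes after every step.
    """
    beams = [(0.0, [start])]
    for _ in range(length):
        expanded = []
        for score, prefix in beams:
            for token, lp in step_log_probs(prefix).items():
                expanded.append((score + lp, prefix + [token]))
        expanded.sort(key=lambda b: b[0], reverse=True)
        beams = expanded[:beam_width]
    return beams


# Toy model: "a" looks best after "s", but "b" leads to a better continuation.
table = {
    "s": {"a": math.log(0.6), "b": math.log(0.4)},
    "a": {"x": math.log(0.55), "y": math.log(0.45)},
    "b": {"z": math.log(0.9), "w": math.log(0.1)},
}


def step(prefix):
    return table[prefix[-1]]


greedy = beam_search(step, "s", 2, beam_width=1)[0][1]
wide = beam_search(step, "s", 2, beam_width=2)[0][1]
print(greedy)  # ['s', 'a', 'x']  (probability 0.33)
print(wide)    # ['s', 'b', 'z']  (probability 0.36)
```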

DimGorr commented 1 year ago

@delpqhiz Thank you so much for the article! Then I have another question (sorry about that :) ). I slightly changed the center identification task, and it now performs better. However, the final retrosynthesis task shows worse results. I thought it was related to beam search, but now I really don't know what's wrong. Do you have any idea why this might happen? It seems unreasonable to me.

delpqhiz commented 1 year ago

@DimGorr I agree that better performance in the center identification task should lead to better overall performance, and I have no idea why this happens. Did you change the number of centers in your script? Your script limits the number of centers to 2 (center_topk=2), which might be too small.

DimGorr commented 1 year ago

@delpqhiz By the number of centers, do you mean num_synthon? If I understand correctly, that part of the Retrosynthesis prediction task asserts that there are only 1 or 2 synthons. I just checked how many graphs have more than 2 synthons and found none, so I'm not sure it will help. Anyway, if you have already changed it, could you please send it to me? :) Of course, only if you are okay with that.
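The check described here (how many reactions split into more than two synthons) can be sketched in plain Python if the synthons of each reaction are available as one SMILES string, where `.` separates disconnected fragments. `count_fragments` and the example strings are hypothetical, and splitting on `.` assumes each synthon is a single connected fragment.

```python
from collections import Counter


def count_fragments(smiles_list):
    """Map fragment count -> number of SMILES strings with that many '.'-separated parts."""
    return Counter(len(s.split(".")) for s in smiles_list)


# Hypothetical synthon SMILES: two single-synthon entries and one with two synthons.
synthons = ["CC(=O)O", "CC(=O)[O].[CH3]", "c1ccccc1"]
counts = count_fragments(synthons)
more_than_two = sum(n for size, n in counts.items() if size > 2)
print(more_than_two)  # 0
```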

DaShenZi721 commented 1 year ago

@DimGorr Hi! Have you solved the information leakage issue described in https://github.com/uta-smile/RetroXpert/issues/15?

DimGorr commented 1 year ago

@DaShenZi721 Hi! Not yet :(