plainerman / DiffDock-Pocket

Implementation of DiffDock-Pocket: Diffusion for Pocket-Level Docking with Side Chain Flexibility
MIT License

The training loss does not converge #7

Open fengshikun opened 1 month ago

fengshikun commented 1 month ago

Hello, I've been trying to train the score model using the command from the README. However, the loss doesn't seem to converge. Could you help me figure out what might be going wrong?

plainerman commented 1 month ago

We experienced similar behavior when training the model. This is why we optimized for rmsds_lt2 rather than the loss. You should see this metric increase. Do you?
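For context: in the original DiffDock training code, `rmsds_lt2` is reported as the percentage of validation complexes whose predicted ligand pose lies within 2 Å RMSD of the crystal pose (`rmsds_lt5` likewise for 5 Å). The sketch below assumes DiffDock-Pocket reports it the same way; `percent_below` is a hypothetical helper, not a function from the repository.

```python
import math
from typing import Sequence


def percent_below(rmsds: Sequence[float], threshold: float) -> float:
    """Percentage of ligand RMSDs (in Å) strictly below `threshold`.

    Hypothetical helper, assuming rmsds_lt2 / rmsds_lt5 are reported
    as percentages as in the original DiffDock code.
    """
    if not rmsds:
        return math.nan  # no validation complexes: metric undefined
    hits = sum(1 for r in rmsds if r < threshold)
    return 100.0 * hits / len(rmsds)


# Example: RMSDs for five made-up validation complexes.
example = [1.2, 3.4, 0.8, 7.9, 2.0]
print(percent_below(example, 2.0))  # 40.0 (1.2 and 0.8)
print(percent_below(example, 5.0))  # 80.0 (1.2, 3.4, 0.8 and 2.0)
```

Under this definition, a logged value of 0.000 means that no validation complex was docked within the threshold.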

fengshikun commented 1 month ago

Sorry, I just checked the logs and noticed that Val inference rmsds_lt2 stays at zero throughout. The validation loss also fluctuates and reaches abnormally large values, for example:

Epoch 19: Validation loss 42793415.1111  tr 171172360.1270   rot 1286.8611   tor 0.9821   sc_tor 0.9899
Epoch 20: Validation loss 330176498.0680  tr 1320697075.8095   rot 8896.0845   tor 0.9705   sc_tor 0.9875
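To see which term drives these spikes, each log line can be parsed and the components compared with one another. This is a minimal sketch that assumes the log format shown above; `parse_val_loss` and `component_shares` are hypothetical helpers. The shares are computed over the four listed components only, because the thread does not say how they are weighted into the total (the reported total is smaller than the `tr` term alone).

```python
import re
from typing import Dict, Tuple

# Matches the start of lines such as
# "Epoch 19: Validation loss 42793415.1111  tr 171172360.1270   rot 1286.8611 ..."
_HEADER = re.compile(r"Epoch\s+(\d+):\s+Validation loss\s+([-+\d.eE]+)")
# Matches each "name value" pair that follows the total loss.
_TERM = re.compile(r"\b(sc_tor|tr|rot|tor)\s+([-+\d.eE]+)")

COMPONENTS = ("tr", "rot", "tor", "sc_tor")


def parse_val_loss(line: str) -> Tuple[int, Dict[str, float]]:
    """Return (epoch, {"loss": total, "tr": ..., "rot": ..., "tor": ..., "sc_tor": ...})."""
    m = _HEADER.search(line)
    if m is None:
        raise ValueError(f"not a validation-loss line: {line!r}")
    values = {"loss": float(m.group(2))}
    # Scan only the text after the total so the header is not re-matched.
    for name, value in _TERM.findall(line[m.end():]):
        values[name] = float(value)
    return int(m.group(1)), values


def component_shares(values: Dict[str, float]) -> Dict[str, float]:
    """Fraction each component contributes to the sum of the components."""
    terms = {k: values[k] for k in COMPONENTS if k in values}
    total = sum(terms.values())
    if not terms or total == 0:
        return {}
    return {k: v / total for k, v in terms.items()}


epoch, vals = parse_val_loss(
    "Epoch 19: Validation loss 42793415.1111  tr 171172360.1270   "
    "rot 1286.8611   tor 0.9821   sc_tor 0.9899"
)
shares = component_shares(vals)
print(epoch, max(shares, key=shares.get))  # 19 tr
```

In both logged epochs, the translational term accounts for more than 99.9% of the component sum, while `tor` and `sc_tor` stay near 1.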

I followed the command exactly as specified in the README file, so I suspect there might be a configuration issue or perhaps a bug in the code.

plainerman commented 1 month ago

It was common for us to have epochs with outliers and very large losses. However, the values should not be consistently this large.
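The distinction between occasional outlier epochs and consistently large losses can be checked with a median-based rule. This is a sketch under assumptions: `flag_outlier_epochs` is hypothetical, and the factor of 10 is an arbitrary choice, not something suggested in the thread.

```python
from statistics import median
from typing import List, Sequence


def flag_outlier_epochs(losses: Sequence[float], factor: float = 10.0) -> List[int]:
    """Return indices of epochs whose loss exceeds `factor` x the median loss.

    The median is used because a few huge spikes barely move it, unlike the mean.
    """
    if not losses:
        return []
    ref = median(losses)
    return [i for i, loss in enumerate(losses) if loss > factor * ref]


# Made-up validation losses: mostly moderate, with two spikes.
history = [3.1, 2.8, 2.9, 4.2e7, 2.7, 2.6, 3.3e8, 2.5]
spikes = flag_outlier_epochs(history)
print(spikes)                      # [3, 6]
print(len(spikes) / len(history))  # 0.25: occasional outliers
```

If most epochs are huge, the median itself becomes huge and few epochs are flagged. That is the "consistently large" case, which shows up as a median far above what a healthy run produces.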

In our run, valinf_rmsds_lt2 only started to look promising after ~50 epochs. How long did you train for?

fengshikun commented 1 month ago

I have trained for approximately 100 epochs, and Val inference rmsds_lt2 remains consistently zero, as the latest results show:

Epoch 89: Val inference rmsds_lt2 0.000 rmsds_lt5 0.000 sc_rmsds_lt2 3.000 sc_rmsds_lt1 0.000, sc_rmsds_lt0.5 0.000 avg_improve 16.225 avg_worse 17.347  sc_rmsds_lt2_from_holo 3.000 sc_rmsds_lt1_from_holo 0.000, sc_rmsds_lt05_from_holo.5 0.000 sc_rmsds_avg_improvement_from_holo 15.128 sc_rmsds_avg_worsening_from_holo 19.187  
Storing best sc_rmsds_lt05_from_holo model
Run name:  big_score_model

Epoch 94: Val inference rmsds_lt2 0.000 rmsds_lt5 0.000 sc_rmsds_lt2 3.000 sc_rmsds_lt1 0.000, sc_rmsds_lt0.5 0.000 avg_improve 16.219 avg_worse 24.843  sc_rmsds_lt2_from_holo 2.000 sc_rmsds_lt1_from_holo 0.000, sc_rmsds_lt05_from_holo.5 0.000 sc_rmsds_avg_improvement_from_holo 16.306 sc_rmsds_avg_worsening_from_holo 18.885  
Storing best sc_rmsds_lt05_from_holo model
Run name:  big_score_model
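Because plainerman's run only looked promising after ~50 epochs, it helps to track rmsds_lt2 across the whole log rather than the latest epochs only. A minimal sketch, assuming the "Epoch N: Val inference rmsds_lt2 X ..." format shown above; both helpers are hypothetical.

```python
import re
from typing import Dict, Iterable, Optional

# Matches e.g. "Epoch 89: Val inference rmsds_lt2 0.000 rmsds_lt5 0.000 ..."
# (anchored right after "Val inference", so sc_rmsds_lt2 is not picked up).
_VALINF = re.compile(r"Epoch\s+(\d+):\s+Val inference\s+rmsds_lt2\s+([\d.]+)")


def rmsds_lt2_history(lines: Iterable[str]) -> Dict[int, float]:
    """Map epoch -> rmsds_lt2 for every 'Val inference' line in a training log."""
    history: Dict[int, float] = {}
    for line in lines:
        m = _VALINF.search(line)
        if m:
            history[int(m.group(1))] = float(m.group(2))
    return history


def first_epoch_above(history: Dict[int, float], threshold: float = 0.0) -> Optional[int]:
    """Earliest epoch whose rmsds_lt2 exceeds `threshold`, or None if it never does."""
    for epoch in sorted(history):
        if history[epoch] > threshold:
            return epoch
    return None


log = [
    "Epoch 89: Val inference rmsds_lt2 0.000 rmsds_lt5 0.000 sc_rmsds_lt2 3.000",
    "Storing best sc_rmsds_lt05_from_holo model",
    "Epoch 94: Val inference rmsds_lt2 0.000 rmsds_lt5 0.000 sc_rmsds_lt2 3.000",
]
hist = rmsds_lt2_history(log)
print(hist)                     # {89: 0.0, 94: 0.0}
print(first_epoch_above(hist))  # None
```

With a real log, `rmsds_lt2_history(open(path))` would read every line of the file.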
fengshikun commented 1 month ago

The complete training log is attached for reference: big_score_model_resume.log

plainerman commented 1 month ago

Could you try --limit_complexes 100 and setting the train set to the validation set?

i.e. --split_train data/splits/timesplit_no_lig_overlap_val_aligned --split_val data/splits/timesplit_no_lig_overlap_val_aligned --limit_complexes 100

Does the problem persist? This will show whether the model can overfit a small subset.
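The overfitting check amounts to taking the README command and overriding three flags. The sketch below does this on an argument list; `with_overfit_flags` is hypothetical, and the `PLACEHOLDER_TRAIN_SPLIT` value in the example is made up. It only handles the `--flag value` form, and it does not cover the line in `datasets/pdbbind.py` that also has to be commented out for this setup.

```python
from typing import List, Sequence

# Flags suggested in the thread for the small-subset overfitting check.
VAL_SPLIT = "data/splits/timesplit_no_lig_overlap_val_aligned"
OVERFIT_FLAGS = {
    "--split_train": VAL_SPLIT,
    "--split_val": VAL_SPLIT,
    "--limit_complexes": "100",
}


def with_overfit_flags(argv: Sequence[str]) -> List[str]:
    """Copy `argv`, dropping any existing OVERFIT_FLAGS (and their values)
    and appending the overfitting values instead.

    Assumes each of these flags is written as "--flag value" (not "--flag=value").
    """
    out: List[str] = []
    skip_next = False
    for arg in argv:
        if skip_next:
            skip_next = False  # this token was the old flag's value
            continue
        if arg in OVERFIT_FLAGS:
            skip_next = True
            continue
        out.append(arg)
    for flag, value in OVERFIT_FLAGS.items():
        out += [flag, value]
    return out


base = ["python", "-u", "train.py", "--run_name", "big_score_model",
        "--split_train", "PLACEHOLDER_TRAIN_SPLIT", "--lr", "1e-3"]
print(" ".join(with_overfit_flags(base)))
```

The resulting list could be passed to `subprocess.run` to launch the run.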

fengshikun commented 1 month ago

Thank you. I'll try it out and will share the results later.

plainerman commented 1 month ago

For this to work, you have to comment out this line:

https://github.com/plainerman/DiffDock-Pocket/blob/98a15235eeec7561fc9f183823273452270487b8/datasets/pdbbind.py#L988

plainerman commented 1 month ago

Maybe related: https://github.com/plainerman/DiffDock-Pocket/issues/6 (which has been fixed). So it may be worth pulling again.

fengshikun commented 1 month ago

Got it, thanks for the reminder.

fengshikun commented 1 month ago

I pulled the latest version of the codebase and trained the score model on only 100 complexes. However, the loss still fluctuates and has not converged. This is the training command I used:

python -u train.py --run_name big_score_model --test_sigma_intervals --log_dir workdir --lr 1e-3 --tr_sigma_min 0.1 --tr_sigma_max 5 --rot_sigma_min 0.03 --rot_sigma_max 1.55 --tor_sigma_min 0.03 --sidechain_tor_sigma_min 0.03 --batch_size 32 --ns 60 --nv 10 --num_conv_layers 6 --distance_embed_dim 64 --cross_distance_embed_dim 64 --sigma_embed_dim 64 --dynamic_max_cross --scheduler plateau --scale_by_sigma --dropout 0.1 --sampling_alpha 1 --sampling_beta 1 --remove_hs --c_alpha_max_neighbors 24 --atom_max_neighbors 8 --receptor_radius 15 --num_dataloader_workers 1 --cudnn_benchmark --rot_alpha 1 --rot_beta 1 --tor_alpha 1 --tor_beta 1 --val_inference_freq 5 --use_ema --scheduler_patience 30 --n_epochs 750 --all_atom --sh_lmax 1 --sh_lmax 1 --split_train data/splits/timesplit_no_lig_overlap_val_aligned --split_val data/splits/timesplit_no_lig_overlap_val_aligned --limit_complexes 100 --pocket_reduction --pocket_buffer 10 --flexible_sidechains --flexdist 3.5 --flexdist_distance_metric prism --protein_file protein_esmfold_aligned_tr_fix --compare_true_protein --conformer_match_sidechains --conformer_match_score exp --match_max_rmsd 2 --use_original_conformer_fallback --use_original_conformer
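The command above passes `--sh_lmax 1` twice. Assuming `train.py` parses arguments with argparse's default `store` action, the last occurrence wins, so this is harmless here, but repeated flags are worth catching before comparing configurations. A minimal sketch; `repeated_flags` is a hypothetical helper.

```python
import shlex
from collections import Counter
from typing import List


def repeated_flags(command: str) -> List[str]:
    """Return long options (tokens starting with "--") that appear more than once."""
    tokens = shlex.split(command)
    counts = Counter(tok for tok in tokens if tok.startswith("--"))
    return sorted(flag for flag, n in counts.items() if n > 1)


cmd = "python -u train.py --lr 1e-3 --sh_lmax 1 --sh_lmax 1 --all_atom"
print(repeated_flags(cmd))  # ['--sh_lmax']
```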

The complete training log is attached: big_score_model.log
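On a 100-complex subset used for both training and validation, the loss is expected to drop clearly if the model can overfit. Whether a fluctuating curve is actually trending down can be checked by comparing medians of an early and a late window, which keeps isolated spike epochs from deciding the result. This is a heuristic sketch: `has_decreased`, the window of 10 epochs and the 0.5 ratio are all assumptions, not values from the thread.

```python
from statistics import median
from typing import Sequence


def has_decreased(losses: Sequence[float], window: int = 10, min_ratio: float = 0.5) -> bool:
    """True if the median loss of the last `window` epochs is at most
    `min_ratio` times the median loss of the first `window` epochs.
    """
    if len(losses) < 2 * window:
        raise ValueError("need at least 2 * window epochs")
    early = median(losses[:window])
    late = median(losses[-window:])
    return late <= min_ratio * early


decreasing = [1.0 / (i + 1) for i in range(20)]  # steadily falling loss
flat = [3.0, 2.5, 3.5] * 7                       # fluctuates without trend
print(has_decreased(decreasing))  # True
print(has_decreased(flat))        # False
```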

glukhove commented 4 weeks ago

Hi @fengshikun, were you able to train the model?

fengshikun commented 2 weeks ago

The loss still does not converge.