MolecularAI / REINVENT4

AI molecular design tool for de novo design, scaffold hopping, R-group replacement, linker design and molecule optimization.
Apache License 2.0

Troubleshooting Validation Loss Issues in REINVENT4 Transfer Learning #123

Closed by kingljy0818 3 months ago

kingljy0818 commented 3 months ago

Hi,

In the REINVENT4 configuration below, Chembl34_filtered.prior is a prior model trained on the ChEMBL34 compound database, which contains approximately 2 million SMILES. Agro_Chemical_Train.smi is the SMILES file used for transfer learning (about 80,000 SMILES), and Agro_Chemical_Validation.smi is the validation set (approximately 8,000 SMILES). However, as shown in the figure below, the validation loss (light blue curve) keeps increasing throughout training. I am unsure how to solve this and would greatly appreciate your help and guidance. I look forward to your response.


run_type = "transfer_learning"
device = "cuda:0"  # Use one GPU
tb_logdir = "tb_TL"
json_out_config = "json_transfer_learning.json"

[parameters]
num_epochs = 120  # Increase number of training epochs
save_every_n_epochs = 5
batch_size = 256
num_refs = 0
sample_batch_size = 512
learning_rate = 0.0005
seed = 42
gradient_clipping = 1.0

# Regularization parameters
weight_decay = 0.001  # Reduce weight decay
dropout = 0.25  # Reduce dropout rate

input_model_file = "Chembl34_filtered.prior"
smiles_file = "Agro_Chemical_Train.smi"
output_model_file = "Agro_Chemical.agent"
validation_smiles_file = "Agro_Chemical_Validation.smi"

[parameters.lr_scheduler]
type = "ReduceLROnPlateau"
factor = 0.5
patience = 5  # Increase patience
min_lr = 1e-6

log_interval = 10

# Randomization option
randomization_type = "initial"  # Options: "initial", "every_epoch"

[validation]
smiles_file = "Agro_Chemical_Validation.smi"

start_time = "2024-06-05 00:00:00"


[Figure TL_2: loss curves from this run; validation loss shown as the light blue curve]

halx commented 3 months ago

I see that the validation loss has a minimum around step 15. Beyond that step the model is starting to overfit. This is basic machine learning and is described in any relevant textbook.
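The standard remedy for this kind of overfitting is to keep the model from the epoch with the lowest validation loss, or to stop training once that loss has stopped improving. The plain-Python sketch below illustrates both ideas under stated assumptions: each plotted step is treated as one epoch, `save_every_n_epochs = 5` is assumed to mean a model is written every fifth epoch, and the function names (`best_epoch`, `early_stop_epoch`, `best_saved_epoch`) and the patience value are hypothetical, not part of REINVENT4. The loss curve in the usage example is synthetic, a U-shape with its minimum placed at epoch 15, and is not taken from the run in the post.

```python
from typing import Sequence


def best_epoch(val_losses: Sequence[float]) -> int:
    """Return the 1-based epoch with the lowest validation loss."""
    return min(range(len(val_losses)), key=lambda i: val_losses[i]) + 1


def early_stop_epoch(val_losses: Sequence[float], patience: int = 5,
                     min_delta: float = 0.0) -> int:
    """Return the 1-based epoch at which patience-based early stopping halts.

    Training stops once the validation loss has failed to beat the best value
    seen so far (by more than min_delta) for `patience` consecutive epochs.
    If that never happens, the last epoch is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses)


def best_saved_epoch(val_losses: Sequence[float], save_every: int = 5) -> int:
    """Among checkpointed epochs (every `save_every`), pick the lowest validation loss."""
    saved = range(save_every, len(val_losses) + 1, save_every)
    return min(saved, key=lambda epoch: val_losses[epoch - 1])


# Synthetic U-shaped validation curve over 40 epochs, minimum at epoch 15.
losses = [1.0 + (epoch - 15) ** 2 / 100 for epoch in range(1, 41)]
print(best_epoch(losses))                    # 15
print(early_stop_epoch(losses, patience=5))  # 20
print(best_saved_epoch(losses, save_every=5))  # 15
```

Selecting among saved checkpoints matters because only every fifth model is kept under the assumption above: the best usable model is the saved one closest to the validation minimum, not necessarily the exact minimum.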