datamol-io / safe

A single model for all your molecular design tasks
https://safe-docs.datamol.io/
Apache License 2.0

Grad Norm and SAFE encoding Misunderstanding #49

Closed Anri-Lombard closed 1 month ago

Anri-Lombard commented 1 month ago

When training a model on a different dataset, in this case https://huggingface.co/datasets/sagawa/ZINC-canonicalized (somewhat larger than MOSES and quite a bit smaller than the SAFE-GPT dataset), the reported perplexity ends up as Infinity:

{
    "epoch": 1.0,
    "eval_runtime": 3393.0287,
    "eval_samples_per_second": 677.64,
    "eval_steps_per_second": 84.705,
    "perplexity": Infinity,
    "total_flos": 5.026540798656038e+16,
    "train_loss": 0.6743136047124862,
    "train_runtime": 6681.5609,
    "train_samples_per_second": 383.144,
    "train_steps_per_second": 2.993
}

Looking into it further I discovered the grad_norm is very large despite explicitly setting max_grad_norm:

"log_history": [
    {
      "epoch": 5e-05,
      "grad_norm": 57.65608215332031,
      "learning_rate": 5.0000000000000004e-08,
      "loss": 7.5185,
      "step": 1
    },
    {
      "epoch": 0.025,
      "grad_norm": 14.687941551208496,
      "learning_rate": 2.5e-05,
      "loss": 2.2887,
      "step": 500
    },
    {
      "epoch": 0.05,
      "grad_norm": 3.8548357486724854,
      "learning_rate": 5e-05,
      "loss": 1.0481,
      "step": 1000
    },
    {
      "epoch": 0.075,
      "grad_norm": 3.50759220123291,
      "learning_rate": 7.500000000000001e-05,
      "loss": 0.8689,
      "step": 1500
    },

The model then does not generate any valid molecules and seems to overfit: [image: training_convergence]

I tried adjusting the library myself and realised that transformers sets max_grad_norm to 1.0 by default, which made sense when I replicated your small model results, since grad_norm stayed between 0 and 1 throughout and the final results were good.

Do you have a solution in mind? It might be that the default is ignored when doing warmup steps but I haven't found any evidence for this reading through the Trainer code (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py).
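One reading that fits the log, offered as an assumption rather than something confirmed in this thread: the logged grad_norm is the norm measured before clipping, so values above max_grad_norm do not by themselves show that clipping was skipped. torch.nn.utils.clip_grad_norm_ returns the total norm computed before scaling, and the sketch below mimics that behaviour in plain Python (the function name is made up for illustration):

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale grads so their global L2 norm is at most max_norm.

    Returns (norm_before_clipping, clipped_grads).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Only ever scale down; the small epsilon avoids division by zero.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return total_norm, [g * scale for g in grads]

# Toy gradient with a large norm, clipped to max_norm=1.0.
pre_clip_norm, clipped = clip_by_global_norm([30.0, 40.0], max_norm=1.0)
post_clip_norm = math.sqrt(sum(g * g for g in clipped))
print("reported (pre-clip) norm:", pre_clip_norm)
print("norm actually applied:", post_clip_norm)
```

Under this reading, a logged value of 57.6 at step 1 would describe the raw gradient, while the update itself used a gradient rescaled to norm 1.0.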

For more context, this is a 50M-parameter model with a learning rate of 1e-4, and the dataset is ~20M ZINC molecules. Do you have any intuition about what the problem might be?
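As a sanity check on the log excerpt, the logged learning rates are consistent with a linear warmup towards the stated peak of 1e-4. The warmup length of 2000 steps used below is inferred from the logged values, not stated anywhere in the thread, and the helper only models the warmup phase:

```python
import math

def linear_warmup_lr(step, peak_lr, warmup_steps):
    """Learning rate during a linear warmup from 0 to peak_lr."""
    return peak_lr * min(1.0, step / warmup_steps)

# (step, learning_rate) pairs copied from the log_history excerpt.
logged = [(1, 5.0000000000000004e-08), (500, 2.5e-05),
          (1000, 5e-05), (1500, 7.500000000000001e-05)]

# warmup_steps=2000 is an inferred value, see the note above.
for step, lr in logged:
    expected = linear_warmup_lr(step, peak_lr=1e-4, warmup_steps=2000)
    print(step, lr, math.isclose(lr, expected, rel_tol=1e-6))
```

If the check holds, every logged step shown here is still inside the warmup, so the large early grad_norm values coincide with very small learning rates.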

maclandrol commented 1 month ago

Hi @Anri-Lombard a few questions:

What happens when you use the same small model on this dataset ?

Anri-Lombard commented 1 month ago

Hi @maclandrol, thank you for the quick response.

To address your last question first, no: I currently have access to a single 80GB A100, so I am not training with multiple GPUs (and thus not trying to replicate your larger model with the 1.1B molecule dataset).

Thank you for clarifying the gradient norm; that's how I intuited it, although was unsure. It explains the results of more iterations I tried 🙏.

To plot the validation loss I need to alter the safe library, since the code currently reports perplexity only at the end (and there is no flag for intermediate recording of validation metrics, unless I'm mistaken?). The dataset mentioned is in SMILES format, but I passed in the is_tokenized False flag; I did this for the small model when training on MOSES as well. I'm retraining the small model on this dataset; that is a great suggestion.

My intuition on tokenizers could be stronger, but I suspected since this is a 20M molecule zinc dataset and you trained the original tokenizer on 1.1B molecules, of which a large subset was zinc, retraining the tokenizer for a smaller zinc subset won't change the results?

Would you mind keeping the issue open for the time being? I can come back to record my findings for others if they happen upon the same situation once training is done.

(For context, the batch size I used was 64 with 2 steps gradient accumulation - just to address your second point)

maclandrol commented 1 month ago

I think you can potentially increase your batch size a bit. Also, if your sequences are not very long after tokenization, try reducing the model's max_length for positional encoding. 1024 is a bit high in most cases.
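One way to choose a smaller max_length is to look at the distribution of tokenized sequence lengths and pick a value that covers most of them. The helper below is a minimal sketch with a made-up name and made-up lengths; in practice the lengths would come from tokenizing a sample of the training set:

```python
import math

def length_covering(lengths, coverage=0.99):
    """Smallest length L such that at least `coverage` of sequences have len <= L."""
    if not lengths:
        raise ValueError("lengths must not be empty")
    ordered = sorted(lengths)
    # Position of the first element that reaches the requested coverage.
    idx = max(0, math.ceil(coverage * len(ordered)) - 1)
    return ordered[idx]

# Illustrative token counts only, not measured on any real dataset.
token_lengths = [12, 18, 25, 31, 40, 44, 52, 60, 75, 140]
print(length_covering(token_lengths, coverage=0.5))
print(length_covering(token_lengths, coverage=1.0))
```

Rounding the result up to a convenient size leaves some headroom for special tokens; sequences above the chosen length would be truncated.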

I would really suggest some light hyperparameter tuning here. You should normally be able to plot the validation loss if you use wandb. Just make sure that your dataset is a DatasetDict and has a key called validation, or test (used if validation is not found). You will need to pass the eval frequency, etc. to the CLI. For example, you would need these:

    --eval_strategy epoch \
    --eval_accumulation_steps 1000 \
    --do_train True \
    --do_eval True \
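The split lookup described above (validation first, then test) can be pictured with the following sketch. It illustrates the described order only and is not the library's actual code; the function name is hypothetical, and a plain dict stands in for a DatasetDict:

```python
def pick_eval_split(splits):
    """Return the name of the split to evaluate on, or None if there is none."""
    for name in ("validation", "test"):
        if name in splits:
            return name
    return None

# Plain dicts stand in for a DatasetDict keyed by split name.
print(pick_eval_split({"train": [], "validation": [], "test": []}))
print(pick_eval_split({"train": [], "test": []}))
print(pick_eval_split({"train": []}))
```

A dataset with only a train split would give nothing to evaluate on, which is worth ruling out before looking for other causes of missing eval metrics.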

> My intuition on tokenizers could be stronger, but I suspected since this is a 20M molecule zinc dataset and you trained the original tokenizer on 1.1B molecules, of which a large subset was zinc, retraining the tokenizer for a smaller zinc subset won't change the results?

You are right that the tokenizer should work for both SAFE and SMILES strings, and a good model should simply ignore any tokens that are not in your training data. One could argue that this is wasteful, since you can likely reduce the vocabulary and thereby also reduce the model size and training time.

Also, just to be sure: if you train on the 20M SMILES dataset without converting to SAFE first, you will get a SMILES model. We don't convert automatically, so that the same training code can be reused for any molecular line representation the user wants.
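Since the conversion is not automatic, it has to happen before training. The sketch below shows one way to add a SAFE column to a list of records while keeping track of failures. The helper name, the "smiles" key, and the stand-in encoder are all assumptions for illustration; with safe-mol installed, safe.encode would be passed as the encode argument instead:

```python
def add_safe_column(rows, encode, smiles_key="smiles", safe_key="SAFE"):
    """Return (converted_rows, failed_smiles) after applying `encode` to each row."""
    converted, failed = [], []
    for row in rows:
        try:
            encoded = encode(row[smiles_key])
        except Exception:
            # Some molecules cannot be encoded; keep them for inspection.
            failed.append(row[smiles_key])
            continue
        converted.append({**row, safe_key: encoded})
    return converted, failed

def stand_in_encode(smiles):
    """Placeholder encoder for the demo; NOT a real SAFE conversion."""
    if not smiles:
        raise ValueError("empty SMILES")
    return "SAFE(" + smiles + ")"

rows = [{"smiles": "CCO"}, {"smiles": ""}, {"smiles": "c1ccccc1"}]
converted, failed = add_safe_column(rows, stand_in_encode)
print(len(converted), "converted,", len(failed), "failed")
```

The resulting column name matches the --text_column "SAFE" setting used later in this thread, so the training command would read the converted strings rather than the raw SMILES.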

Anri-Lombard commented 1 month ago

> Also, just to be sure, if you train on the 20M SMILES dataset without converting to SAFE first, you will get a SMILES model, as we don't automatically convert to allow reusing the same training code for any molecular line representation the user wants to use.

Now this is extremely interesting... This could also explain the performance of my model to an extent: when I retrained the 50M model with different hyperparameters, it produced about 50% valid molecules (although it is a SMILES model). In contrast, the small 20M SAFE model matched the results claimed in your paper: [image: Screenshot 2024-07-25 at 20 58 49]

Mentioning that the library does not convert automatically is extremely helpful. I'm slightly embarrassed that I was relying on that assumption, and I am realising in real time that tokenizer.py does not include the is_tokenized flag to automatically convert to SAFE. This might explain why larger datasets are difficult to train on.

Fantastic; thank you! 🙂

Anri-Lombard commented 1 month ago

When running the training with wandb, I notice it does not record the eval loss: [image: Screenshot 2024-07-27 at 17 44 52]

I also ran training with a library I built from scratch for another architecture, similar to the safe-mol library, and that one does plot it.

These are my settings:

safe-train --config $config_path \
  --tokenizer $tokenizer_path \
  --dataset $dataset_path \
  --text_column "SAFE" \
  --optim "adamw_torch" \
  --learning_rate 5e-4 \
  --per_device_train_batch_size 32 \
  --per_device_eval_batch_size 32 \
  --gradient_accumulation_steps 2 \
  --report_to "wandb" \
  --warmup_steps 20000 \
  --logging_first_step True \
  --logging_steps 500 \
  --eval_accumulation_steps 8 \
  --save_steps 100 \
  --eval_steps 500 \
  --eval_strategy "steps" \
  --wandb_project "SAFE_small" \
  --num_train_epochs 10 \
  --save_total_limit 1 \
  --output_dir $output_dir \
  --overwrite_output_dir True \
  --do_train True \
  --load_best_model_at_end True \
  --do_eval True \
  --save_safetensors True \
  --gradient_checkpointing True

I'm curious whether the library gives you this problem as well with the newest transformers and accelerate versions?

This perhaps explains the Infinity value for perplexity. Reference:

if training_args.do_eval:
    logger.info("*** Evaluate ***")
    results = trainer.evaluate()
    try:
        perplexity = math.exp(results["eval_loss"])
    except Exception as e:
        logger.error(e)
        perplexity = float("inf")
    results.update({"perplexity": perplexity})
    if trainer.is_world_process_zero():
        trainer.log_metrics("eval", results)
        trainer.save_metrics("eval", results)
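The broad except in the snippet above turns a missing eval_loss key into an infinite perplexity, which is indistinguishable from a genuinely exploding loss. Note that the metrics at the top of this issue contain no eval_loss entry at all. A minimal sketch of a variant that separates the two cases, using a hypothetical helper name and not taken from the library:

```python
import math

def perplexity_from_metrics(metrics):
    """Return (perplexity, problem); problem is None when all went well."""
    loss = metrics.get("eval_loss")
    if loss is None:
        # The evaluation loop never produced a loss.
        return None, "eval_loss missing from evaluation metrics"
    try:
        return math.exp(loss), None
    except OverflowError:
        # The loss exists but is too large to exponentiate.
        return float("inf"), "eval_loss too large for math.exp"

# Metrics shaped like the ones reported in this issue: no eval_loss key.
print(perplexity_from_metrics({"epoch": 1.0, "eval_runtime": 3393.0287}))
print(perplexity_from_metrics({"eval_loss": 0.0}))
print(perplexity_from_metrics({"eval_loss": 1e6}))
```

Reporting the reason next to the value makes it clear whether to debug the evaluation setup or the training run itself.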

If this turns out to be an actual bug, I'll try to track it down and open a PR.

(Testing this on the smaller model before moving to the larger model)

maclandrol commented 1 month ago

I suspect that not getting eval_loss is because you are not specifying the label_names. The SAFE model actually has additional heads to predict molecular properties that are not used when not specified. The transformers library can be confused by this additional label.
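The mechanism can be pictured as follows. This is a simplified paraphrase of how the transformers Trainer decides whether a batch carries labels, written from memory rather than copied from its source, and "mc_labels" is an assumed name for the extra property-head label:

```python
def has_all_labels(inputs, label_names):
    """True only if every expected label is present in the batch."""
    if not label_names:
        return False
    return all(inputs.get(name) is not None for name in label_names)

# A language-modelling batch with no property labels.
batch = {"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]}

# Label names inferred from a two-headed model (assumed names).
print(has_all_labels(batch, ["labels", "mc_labels"]))
# Label names restricted to the language-modelling head.
print(has_all_labels(batch, ["labels"]))
```

When the check fails, no loss is computed during evaluation, so eval_loss never reaches the metrics. Setting label_names explicitly, for example through the label_names field of TrainingArguments, should restrict the check to the labels that are actually provided.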

I have fixed the default behaviour in #53, please let me know if you get the eval_loss with your setup now.