shamim-hussain / tgt

Triplet Graph Transformer
MIT License

How can I use it for downstream task #3

Closed: 295825725 closed this issue 7 months ago

295825725 commented 7 months ago

Hi, I want to ask how I should fine-tune this model so that I can use it for other tasks. For example, instead of using it as a gap predictor, I want to predict solubility. What is the training process for this? Should I start by pretraining the predictor, or do I only need to do the fine-tuning stage?

shamim-hussain commented 7 months ago

We hypothesize that transfer learning happens through the model's understanding of molecular geometry. So, in our molecular property prediction experiments (e.g., on MOLPCBA), we transfer only the pairwise distance predictor and train the task-specific predictor (e.g., a solubility predictor) from scratch.

(We found that fine-tuning the gap predictor is only relevant when targeting other, related quantum chemical properties (e.g., the QM9 tasks). Solubility prediction, on the other hand, is not directly related to HOMO-LUMO gap prediction, so we should not expect positive knowledge transfer here. Previous models probably benefited from pretraining on PCQM4Mv2 because of their indirect understanding of molecular geometry; in our case, the distance predictor learns this directly.)

The pairwise distances predicted by our distance predictor are fed as input features to the next stage (in your case, the solubility predictor). As a form of data augmentation, we draw multiple distance samples from the distance predictor with dropout turned on.

In our experiments, we do not fine-tune the distance predictor but rather use it as a frozen feature extractor. These predicted distances result in better performance on downstream tasks than even RDKit-generated coordinates.
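A minimal sketch of the sampling idea, written in plain Python so it runs without TGT or PyTorch. `toy_distance_predictor` and `sample_distance_matrices` are hypothetical names, not the repository's API: the toy predictor only mimics a frozen, trained distance predictor whose weights never change, while dropout left on makes repeated calls on the same molecule return slightly different (but still symmetric, zero-diagonal) distance matrices.

```python
import random
from typing import List

Matrix = List[List[float]]


def toy_distance_predictor(num_atoms: int, rng: random.Random,
                           p_drop: float = 0.2, training: bool = True) -> Matrix:
    """Hypothetical stand-in for a frozen distance predictor.

    The 'weights' (here a fixed base matrix) are never updated; only the
    dropout mask differs between calls when training=True.
    """
    base = [[1.5 * abs(i - j) for j in range(num_atoms)] for i in range(num_atoms)]
    if not training:
        return base  # dropout off: the output is deterministic
    # Inverted dropout on one hidden unit per atom: 0 with prob p_drop, else 1/(1-p_drop).
    mask = [0.0 if rng.random() < p_drop else 1.0 / (1.0 - p_drop)
            for _ in range(num_atoms)]
    # The mask perturbs each distance symmetrically; the diagonal stays zero.
    return [[base[i][j] * (0.9 + 0.05 * (mask[i] + mask[j])) if i != j else 0.0
             for j in range(num_atoms)] for i in range(num_atoms)]


def sample_distance_matrices(num_atoms: int, num_samples: int,
                             seed: int = 0) -> List[Matrix]:
    """Draw several distance matrices for one molecule with dropout left on."""
    rng = random.Random(seed)
    return [toy_distance_predictor(num_atoms, rng, training=True)
            for _ in range(num_samples)]


samples = sample_distance_matrices(num_atoms=5, num_samples=16)
print(len(samples), "distance matrices for one molecule")
```

In a PyTorch implementation, the same effect would presumably come from calling `model.train()` (dropout active) inside a `torch.no_grad()` block and never stepping an optimizer on the predictor, so it stays frozen while dropout still produces different samples. This assumes the predictor has no batch-norm layers, whose running statistics would also be updated in training mode.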

For more details, please refer to our paper, arXiv:2402.04538.

We will add an implementation of fine-tuning on MOLPCBA soon. For now, here is what you could look into:

  1. Download the (noRDKit) distance predictor weights. You can directly use this distance predictor to make distance predictions on your target dataset.
  2. The molecules (SMILES strings) are converted into graphs using the smiles2graph function from OGB (this is the "2D" graph). Once this graph is fed into the distance predictor, it outputs the pairwise distances for that molecule.
  3. Make predictions in training mode, i.e., with dropout turned on. Feeding the same molecule in multiple times then yields multiple distance matrices for it (similar to generating multiple conformations of a molecule). Save them.
  4. Load the saved distances and use them as input features for your specific task, similar to how our HOMO-LUMO gap predictor uses them. You may even use architectures other than TGT/EGT.
  5. Since you will end up with multiple distance matrices for the same molecule, you should combine the predictions made from the individual distance matrices during inference. For classification, you can average the predicted probabilities across the distance matrices.
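Steps 4 and 5 can be sketched as follows, again in plain Python and under stated assumptions: `expand_training_set`, `predict_molecule`, and `toy_predict_proba` are hypothetical names, not part of the TGT repository, and `toy_predict_proba` stands in for whatever downstream classifier you train (TGT, EGT, or another architecture). The sketch only shows the bookkeeping: each saved distance matrix becomes a separate training example carrying its molecule's label, and at inference the per-matrix probabilities for one molecule are averaged.

```python
import math
from statistics import fmean
from typing import Callable, Dict, List, Sequence, Tuple

Matrix = List[List[float]]


def expand_training_set(dist_samples: Dict[str, List[Matrix]],
                        labels: Dict[str, int]) -> List[Tuple[Matrix, int]]:
    """Training: every saved distance matrix becomes its own example,
    labelled with the label of the molecule it was sampled for."""
    return [(dist, labels[mol_id])
            for mol_id, samples in dist_samples.items()
            for dist in samples]


def predict_molecule(samples: Sequence[Matrix],
                     predict_proba: Callable[[Matrix], float]) -> float:
    """Inference: average the predicted probabilities over all distance
    matrices sampled for one molecule."""
    return fmean(predict_proba(dist) for dist in samples)


def toy_predict_proba(dist: Matrix) -> float:
    """Hypothetical downstream classifier: a logistic function of the mean
    off-diagonal distance (stands in for a trained model)."""
    n = len(dist)
    mean_d = fmean(dist[i][j] for i in range(n) for j in range(n) if i != j)
    return 1.0 / (1.0 + math.exp(mean_d - 2.0))


# Two sampled 2-atom distance matrices for one toy molecule labelled 1.
samples_a = [[[0.0, 1.0], [1.0, 0.0]], [[0.0, 3.0], [3.0, 0.0]]]
train_set = expand_training_set({"mol_a": samples_a}, {"mol_a": 1})
print(len(train_set))  # 2 training examples, both labelled 1
print(round(predict_molecule(samples_a, toy_predict_proba), 3))  # 0.5
```

For a regression target such as a solubility value, the analogous choice would be to average the predicted values instead of probabilities; that extension is an assumption here, since the steps above only describe the classification case.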