westlake-repl / SaProt

Saprot: Protein Language Model with Structural Alphabet (AA+3Di)
MIT License

How to fine-tune locally with SaProtHub datasets? #72

Open dc2211 opened 2 days ago

dc2211 commented 2 days ago

Hi everyone,

I’m fairly new to Hugging Face, and I was wondering whether it is possible to fine-tune SaProt locally with the SaProtHub datasets, and how to call the models from there as well, rather than using Colab, since I keep getting runtime disconnections and constantly need to restart the process.

As an example, I would like to try fine-tuning a model locally, similar to the Thermostability example, but with this dataset.

Also, how can I use the models from SaProtHub locally, like this one? I understand the weights are not there, so it is probably more documentation of the performance and config.

Many thanks!

LTEnjoy commented 2 days ago

Hi,

You could indeed load models and datasets from Hugging Face, but it may be more suitable for you to use the ColabSaprot interface while fine-tuning the model on your local GPU. It's very simple: you only have to follow the steps at https://github.com/westlake-repl/SaprotHub/tree/main/local_server.

dc2211 commented 1 day ago

Thanks for the quick response. I managed to connect to the local runtime correctly. Still, I think there must be something that is not suited to my setup.

To request the resources I need to train SaProt, I have to srun and request the GPU I will use. When I do this, I'm not on the login node anymore, so I go from http://login:12315/?token… to http://gpu:12315/?token…

My ideal case would be to directly get the .csv of a particular dataset from SaprotHub, generate the .mdb splits (as in the thermostability example), and fine-tune the model. Can this be done? The only part I am missing is the generation of the .mdb files for LMDB. Any suggestions on this, please?

Thanks

LTEnjoy commented 1 day ago

Sure. We provide some functions to generate the LMDB file, either from a dictionary or from a file.

You could check utils/generate_lmdb.py and refer to the function jsonl2lmdb, which converts your dataset into an LMDB file. All you have to do is generate a .jsonl file from the .csv file. Each line of the .jsonl file is a dictionary containing one sample, e.g.:

jsonl file:
{"name": "xxx", "seq": "AAAAAAA", "label": 1}
{"name": "xxx", "seq": "BBBBBBB", "label": 0}
{"name": "xxx", "seq": "CCCCCCC", "label": 1}
...

Then you could call this function to generate an LMDB file. Note that you have to generate three .jsonl files, for training, validation, and test separately.
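
For example, a minimal sketch of the .csv → .jsonl → LMDB flow might look like this (the column names and the exact jsonl2lmdb signature here are illustrative; check your .csv and utils/generate_lmdb.py before running):

import json
import pandas as pd
from utils.generate_lmdb import jsonl2lmdb

# Assumed columns: "name", "seq", "label", plus a hypothetical "stage"
# column that marks the split. Adapt these to your actual .csv.
df = pd.read_csv("dataset.csv")

for split in ["train", "valid", "test"]:
    jsonl_path = f"{split}.jsonl"
    with open(jsonl_path, "w") as f:
        for _, row in df[df["stage"] == split].iterrows():
            sample = {"name": row["name"],
                      "seq": row["seq"],
                      "label": int(row["label"])}  # cast for JSON serialization
            f.write(json.dumps(sample) + "\n")

    # Assumed signature: jsonl2lmdb(input_jsonl, output_lmdb) -- verify in the repo
    jsonl2lmdb(jsonl_path, f"{split}.lmdb")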

dc2211 commented 1 day ago

Thanks a lot. It worked great, and I managed to fine-tune some models :)

One last thing before closing: from the examples, I noticed that they all relate to the mutation effect of a single substitution on a protein. I want to do two things:

  1. Predict the effect of combinatorial mutations (basically anything with two or more substitutions) on a protein
  2. Predict the fitness (or thermostability, from the fine-tuned model) of multiple independent protein sequences, with the idea of comparing and ranking them

I followed #51 and used the expression M1A:V3D:H8K… to predict the effect of multiple mutations on one protein, but the fitness values were really high (>30). Is this normal? Is there an actual limit on the predicted score? I had the feeling that the predicted score was the sum of all the single-point mutation scores, but I'm probably missing something.

Many thanks again for all the help.

LTEnjoy commented 1 day ago

Hi,

Yes. If you input combinatorial mutations to the model, the predicted score is the sum of all the single-point mutation scores. Here we follow the original ESM-1v paper, which assumes the score is additive when multiple mutations exist.
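
As a quick sanity check, you could compare the two yourself (a sketch assuming model and seq are loaded as in #51, and that predict_mut accepts both single and colon-separated mutation strings):

muts = ["M1A", "V3D", "H8K"]

combined = model.predict_mut(seq, ":".join(muts))        # "M1A:V3D:H8K"
additive = sum(model.predict_mut(seq, m) for m in muts)  # single scores summed

# Under the additivity assumption these should match (up to floating point)
print(combined, additive)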

Please note that SaProt tends to perform well on variants with fewer mutation sites. Predicting the mutational effect of multiple mutations is harder, as the fitness landscape becomes more and more complicated and there may be epistasis.

dc2211 commented 20 hours ago

Thank you. Makes sense, especially the part about epistasis.

How about point 2: predicting fitness (or any other downstream task) for multiple protein sequences?

Is this possible? Thank you

LTEnjoy commented 18 hours ago

Yes, I think point 2 is possible, as shown by the thermostability downstream task. You just need to prepare some samples labeled with the property of interest and use them to fine-tune SaProt. Then you could use the model to run inference and rank the candidates.

dc2211 commented 17 hours ago

And that is where I get confused. I already have the fine-tuned model, and I can do

mut_info = "V3A"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)

to predict the mutation effect of one (or more) substitutions. But what I would like to do is something more like

seqs = ["M#EvVpQpL#VyQdYaKv", "M#AvVpSpL#QyQdKvYa"]
for s in seqs:
    fit = model.predict(s)
    print(fit)

and end up with one prediction per sequence that I would later use for ranking.

Thanks again

LTEnjoy commented 17 hours ago

Yes. If you already have the fine-tuned model, you can just follow the second way to predict the fitness. In fact, the first way of predicting the mutational effect doesn't require additional training: it works in a zero-shot manner, so you only need to load the original SaProt weights to do the prediction (fine-tuned models cannot do this).

Now that you have the fine-tuned model, you can directly use it to make predictions. So your question is how to load your model and do something in the second way?

dc2211 commented 17 hours ago

Yes, more like how to pass a whole sequence and get a prediction from the fine-tuned or zero-shot model. Calling predict_mut works for single substitutions, and I'm not sure whether there is another function that does this for a whole sequence, e.g.:

seq1, score1
seq2, score2

LTEnjoy commented 16 hours ago

We don't provide a function call that directly gets outputs from a fine-tuned model. However, you could easily load your model and obtain them manually. Here we use a regression model as an example:

from model.saprot.saprot_regression_model import SaprotRegressionModel

config = {
    "config_path": "/sujin/Models/SaProt/SaProt_35M_AF2",  # path to the base model
    "from_checkpoint": "/Path/to/your/checkpoint",         # your fine-tuned weights
}

model = SaprotRegressionModel(**config)
model.to("cuda")

# Tokenize a structure-aware (AA+3Di) sequence
seq = "M#EvVpQpL#VyQdYaKv"
inputs = model.tokenizer.batch_encode_plus([seq], return_tensors="pt", padding=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

outputs = model(inputs)
print(outputs)

The output is the final prediction of the model, and you could use this score to rank your candidates.
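
For example, to score and rank several sequences at once (a sketch that reuses the model loaded above; it assumes the forward pass returns a tensor with one score per sequence):

seqs = ["M#EvVpQpL#VyQdYaKv", "M#AvVpSpL#QyQdKvYa"]

inputs = model.tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

scores = model(inputs)  # assumed shape: (len(seqs),), one score per sequence

# Rank candidates from highest to lowest predicted score
ranked = sorted(zip(seqs, scores.tolist()), key=lambda p: p[1], reverse=True)
for s, score in ranked:
    print(s, score)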