westlake-repl / SaProt

Saprot: Protein Language Model with Structural Alphabet (AA+3Di)
MIT License

1) sequence recovery (sequence design) code request 2) generating .mdb file 3) PEFT 4) best model selection after training #73

Open sj584 opened 2 days ago

sj584 commented 2 days ago

Hi,

I successfully ran the finetuning code using config/pretrain/saprot.py and config/Thermostability/saprot.py. Afterwards, a few new questions came up.

I would really appreciate it if you could answer these.




1. Could you share the sequence recovery (or sequence design) code? I implemented it in my own way, but I'm not sure whether it is correct.

My pseudocode is as follows (a rough code sketch follows the list):

1) Given: initial_tokens = ['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv'] (initial sequence) and input_tokens = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv'] (the same sequence with the amino-acid subtokens masked: ##, #p, #y, #d).

2) The model predicts a single combined token (sequence + structure subtoken) at each masked position, so the predicted structure subtoken can be wrong (e.g. ## -> Gr, #p -> Gp, #y -> Gp, #d -> Sd).

3) Extract only the sequence subtoken from each predicted token and reconstruct the sequence, keeping the original structure subtoken:

input_tokens     = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv']
recovered_tokens = ['G#', 'Ev', 'Gp', 'Qp', 'L#', 'Gy', 'Sd', 'Ya', 'Kv']
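
In code, what I did is roughly the following (just a sketch, not the repo's official implementation; it assumes the SaProt checkpoint is loaded through EsmTokenizer/EsmForMaskedLM as in the README, and model_path is a placeholder):

```python
import torch
from transformers import EsmTokenizer, EsmForMaskedLM

model_path = "/your/path/to/SaProt_650M_AF2"  # placeholder path
tokenizer = EsmTokenizer.from_pretrained(model_path)
model = EsmForMaskedLM.from_pretrained(model_path).eval()

input_tokens = ['##', 'Ev', '#p', 'Qp', 'L#', '#y', '#d', 'Ya', 'Kv']
inputs = tokenizer("".join(input_tokens), return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # shape: [1, len(input_tokens) + 2, vocab_size]

recovered_tokens = []
for pos, tok in enumerate(input_tokens, start=1):  # offset 1 for the <cls> token
    if tok[0] == "#":  # amino-acid subtoken is masked
        pred_id = logits[0, pos].argmax(-1).item()
        pred_tok = tokenizer.convert_ids_to_tokens(pred_id)
        # Keep only the predicted amino-acid letter; keep the original structure subtoken
        recovered_tokens.append(pred_tok[0] + tok[1])
    else:
        recovered_tokens.append(tok)

print(recovered_tokens)
```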





2. I also wrote code to generate the .mdb dataset file. I checked that it runs OK, but I'm not sure whether the keys can be arbitrary (e.g. 550, 5500). I would appreciate it if you could check this code against yours.

Generating .mdb file

```python
import lmdb
import json

# Example data
data = {
    "550": {"description": "A0A0J6SSW7", "seq": "M#R#A#A#A#T#L#L#V#T#L#C#V#V#G#A#N#E#A#R#A#GfIwLe..."},
    "5500": {"description": "A0A535NFD5", "seq": "AdAvRvEvAvLvRvAvSvGvHdPdFdVdEdAdPpGpEpAaAdFp..."},
    # Add more entries here
}

# Open (or create) an LMDB environment; map_size is the maximum size (in bytes) of the DB
env = lmdb.open("my_lmdb_file", map_size=int(1e9))

with env.begin(write=True) as txn:
    # Store the dataset length; SaprotFoldseekDataset reads it via int(self._get("length"))
    length = len(data)
    txn.put("length".encode("utf-8"), str(length).encode("utf-8"))

    for key, value in data.items():
        # Convert the value to a JSON string
        value_json = json.dumps(value)
        # Store key-value pairs in the database; keys must be bytes
        txn.put(key.encode("utf-8"), value_json.encode("utf-8"))

# Close the LMDB environment
env.close()
```

Reading .mdb file

```python
import lmdb

env = lmdb.open("my_lmdb_file", readonly=True)

with env.begin() as txn:
    cursor = txn.cursor()
    for key, value in cursor:
        print(key, value)

env.close()
```





3. I once asked whether PEFT is possible, and you kindly answered that it is available in SaprotBaseModel.py. In the code, I could see that LoRA can be used for downstream tasks.

In my case, I was hoping to first use LoRA for MLM finetuning on a particular protein domain, and then do further finetuning on a downstream task.

I managed to write the code, but I don't think this kind of approach has been tried here before, so I wanted to ask your opinion on whether it is viable.

So the steps would be:
1) Load the SaProt model weights
2) Use LoRA for MLM finetuning
3) Load (SaProt model weights + LoRA MLM finetuning weights)
4) Finetune on the downstream task
5) Load (SaProt model weights + LoRA MLM finetuning weights + LoRA downstream finetuning weights)
6) Predict on the downstream task

Or, since the steps above are quite complicated, could the downstream task simply be done by extracting embeddings from (SaProt model weights + LoRA MLM finetuning weights)?
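
For reference, what I have in mind for steps 2)-4) is roughly the following (just a sketch with the HuggingFace peft library; the target modules, ranks and adapter paths are placeholders, not values taken from this repo):

```python
from transformers import EsmForMaskedLM
from peft import LoraConfig, get_peft_model

model_path = "/your/path/to/SaProt_650M_AF2"  # placeholder path
base = EsmForMaskedLM.from_pretrained(model_path)

# Step 2: attach a LoRA adapter and fine-tune with the MLM objective
mlm_lora = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"])
mlm_model = get_peft_model(base, mlm_lora)
# ... MLM training loop on the protein domain of interest ...
mlm_model.save_pretrained("lora_mlm_adapter")  # saves only the adapter weights

# Step 3: fold the MLM adapter into the base weights
merged = mlm_model.merge_and_unload()

# Step 4: attach a fresh LoRA adapter on top of the merged weights and
# fine-tune it on the downstream task (with a task-specific head)
task_lora = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"])
task_model = get_peft_model(merged, task_lora)
# ... downstream fine-tuning ...
task_model.save_pretrained("lora_downstream_adapter")
```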





4. When I ran the code using config/pretrain/saprot.py or config/Thermostability/saprot.py, it seems that only one model is saved after training. If so, how can I know whether the saved model is the optimal one?

I could see that enable_checkpointing: false is set in the Trainer config. Should I change it to True, track the results with wandb, and pick the best model from there?
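
To make this concrete, the change I have in mind looks something like this (a sketch using a plain PyTorch Lightning ModelCheckpoint callback, not the repo's own config; the monitored metric name "valid_loss" is just a placeholder for whatever the LightningModule logs):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the checkpoint with the best validation metric
checkpoint_cb = ModelCheckpoint(
    monitor="valid_loss",   # placeholder metric name
    mode="min",             # "min" for losses, "max" for accuracy-like metrics
    save_top_k=1,           # keep only the single best checkpoint
    dirpath="checkpoints/",
    filename="best-{epoch}-{valid_loss:.4f}",
)

trainer = pl.Trainer(
    enable_checkpointing=True,   # was false in the config
    callbacks=[checkpoint_cb],
    max_epochs=10,
)
# trainer.fit(model, datamodule)
```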





Thank you for reading these long inquiries. Your answers will be very helpful to me :)

LTEnjoy commented 2 days ago

Hi,

Glad to see you digging so deeply into the code!

  1. Could you share sequence recovery (or sequence design) code?

Of course! I have uploaded a new model file named saprot_if_model.py, which is used for protein inverse folding. The overall pipeline is nearly the same as what you described above, and you can check the predict function for more details. You can simply follow the example below to do inverse folding:

```python
from model.saprot.saprot_if_model import SaProtIFModel

# Load model
config = {
    "config_path": "/your/path/to/SaProt_650M_AF2_inverse_folding", # Please download the weights from https://huggingface.co/westlake-repl/SaProt_650M_AF2_inverse_folding
    "load_pretrained": True,
}

device = "cuda"
model = SaProtIFModel(**config)
model = model.to(device)

aa_seq = "##########" # All masked amino acids will be predicted. You could also partially mask the amino acids.
struc_seq = "dddddddddd"

# Predict amino acids given the structure sequence
pred_aa_seq = model.predict(aa_seq, struc_seq)
print(pred_aa_seq)
```
  2. About the generation of the .mdb file

Sorry, I didn't see your previous issue asking for the code to generate the .mdb file. You can refer to this reply https://github.com/westlake-repl/SaProt/issues/72#issuecomment-2478976278 to generate your own .mdb dataset.

  3. Using LoRA for MLM training

I think it may not be necessary to first fine-tune SaProt with the MLM objective and then fine-tune it on the downstream task. In my opinion, if you already have some labeled data, you can directly fine-tune the model on that data; there is no need to do MLM pre-training first. I would guess the final performance is comparable.

  4. The strategy for saving a checkpoint

I believe this issue https://github.com/westlake-repl/SaProt/issues/69 can resolve your question :)

Overall, thank you again for asking such good questions! If you have any other questions, let me know and I'd be happy to help :)