westlake-repl / SaProt

[ICLR'24 spotlight] SaProt: Protein Language Model with Structural Alphabet
MIT License

Getting protein embeddings #14

Closed: memgonzales closed this issue 5 months ago

memgonzales commented 5 months ago

Hello, thank you very much for this wonderful work!

I was just wondering: given a protein sequence and its corresponding 3Di structure-sequence (obtained via Foldseek), is it possible to extract fixed-length protein embeddings using SaProt? Is there a sample script for this particular task?

Thank you.

LTEnjoy commented 5 months ago

Hi, thank you for being interested in our work!

For your question: yes, we provide a simple function to extract a fixed-length protein embedding given a structure-aware sequence. An example is shown below:

from model.saprot.base import SaprotBaseModel
from transformers import EsmTokenizer

config = {
    "task": "base",
    "config_path": "/your/path/to/SaProt_650M_AF2",  # local path to the SaProt checkpoint
    "load_pretrained": True,
}

model = SaprotBaseModel(**config)
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])

device = "cuda"
model.to(device)

# Structure-aware sequence: uppercase letters are amino acids,
# lowercase letters are the corresponding Foldseek 3Di structure tokens
seq = "MdEvVpQpLrVyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Mean-pool the last hidden states to get one fixed-length embedding per sequence
embeddings = model.get_hidden_states(inputs, reduction="mean")
print(embeddings[0].shape)
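
If you start from a structure file instead of a ready-made structure-aware sequence, the repository also ships a Foldseek helper. A minimal sketch, assuming the get_struc_seq helper in utils/foldseek_util.py and a local foldseek binary (both paths below are placeholders, and the exact signature may differ across versions):

from utils.foldseek_util import get_struc_seq

# Placeholder paths: point these at your foldseek binary and structure file
foldseek_bin = "/your/path/to/foldseek"
pdb_path = "/your/path/to/structure.pdb"

# Encode chain "A"; each chain maps to (aa_seq, foldseek_seq, combined_seq)
aa_seq, foldseek_seq, combined_seq = get_struc_seq(foldseek_bin, pdb_path, ["A"])["A"]

# combined_seq interleaves amino acids (uppercase) with 3Di tokens (lowercase),
# which is the input format expected by the tokenizer above
seq = combined_seq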

Hope this resolves your question, and let me know if there is anything else I can help with!

memgonzales commented 5 months ago

Thank you very much!

wangze09 commented 1 month ago

> Hi, thank you for being interested in our work!
>
> For your question: yes, we provide a simple function to extract a fixed-length protein embedding given a structure-aware sequence. An example is shown below:
>
> from model.esm.base import EsmBaseModel
> from transformers import EsmTokenizer
>
> config = {
>     "task": "base",
>     "config_path": "/your/path/to/SaProt_650M_AF2",
>     "load_pretrained": True,
> }
>
> model = EsmBaseModel(**config)
> tokenizer = EsmTokenizer.from_pretrained(config["config_path"])
>
> device = "cuda"
> model.to(device)
>
> seq = "MdEvVpQpLrVyQdYaKv"
> tokens = tokenizer.tokenize(seq)
> print(tokens)
>
> inputs = tokenizer(seq, return_tensors="pt")
> inputs = {k: v.to(device) for k, v in inputs.items()}
>
> embeddings = model.get_hidden_states(inputs, reduction="mean")
> print(embeddings[0].shape)
>
> Hope this resolves your question, and let me know if there is anything else I can help with!

Hello, thanks for your brilliant work! However, I could not find the module "model.esm" when I ran the example code to extract embeddings. Should I run the code on Hugging Face instead? Thank you!

LTEnjoy commented 1 month ago

Hi, we recently made some modifications to our files. The module "esm" was renamed to "saprot", so you can extract embeddings by simply changing the import (from model.saprot.base import SaprotBaseModel, as in the example above). We highly recommend reading through the .py files in case of code mismatches caused by the updates!
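
If you would rather avoid the repository code entirely, the checkpoint published on the Hugging Face Hub can also be loaded with plain transformers. A minimal sketch, assuming the westlake-repl/SaProt_650M_AF2 Hub checkpoint and simple mean pooling over residue positions (this mimics, but is not identical to, the get_hidden_states helper above):

import torch
from transformers import EsmTokenizer, EsmModel

# Load the SaProt checkpoint hosted on the Hugging Face Hub
tokenizer = EsmTokenizer.from_pretrained("westlake-repl/SaProt_650M_AF2")
model = EsmModel.from_pretrained("westlake-repl/SaProt_650M_AF2")
model.eval()

# Structure-aware sequence: amino acids (uppercase) interleaved with 3Di tokens (lowercase)
seq = "MdEvVpQpLrVyQdYaKv"
inputs = tokenizer(seq, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)

# Drop the <cls> and <eos> tokens, then mean-pool over residue positions
embedding = out.last_hidden_state[0, 1:-1].mean(dim=0)
print(embedding.shape)  # expected: torch.Size([1280]) for the 650M model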

wangze09 commented 1 month ago

> Hi, we recently made some modifications to our files. The module "esm" was renamed to "saprot", so you can extract embeddings by simply changing the import. We highly recommend reading through the .py files in case of code mismatches caused by the updates!

Ohhh, thanks a lot!