westlake-repl / SaProt

[ICLR'24 spotlight] SaProt: Protein Language Model with Structural Alphabet
MIT License

Output dimension #42

Closed: douzhuang closed this issue 6 days ago

douzhuang commented 1 week ago

"Hello, I would like to ask if the model output is only a vector of torch.Size([1, ***, 446]), is there a pooled vector available?"

LTEnjoy commented 1 week ago

Hello, thank you for your interest in our work!

We provide an easy way to extract a pooled, sequence-level representation (the per-residue hidden states averaged over the sequence). See the example below:

import torch

from model.saprot.base import SaprotBaseModel
from transformers import EsmTokenizer

config = {
    "task": "base",
    "config_path": "/sujin/Models/SaProt/SaProt_650M_AF2",  # local checkpoint directory
    "load_pretrained": True,
}

model = SaprotBaseModel(**config)
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])

device = "cuda"
model.to(device)

# A structure-aware sequence: each residue is an amino acid (uppercase)
# followed by its Foldseek 3Di structure token (lowercase)
seq = "MdEvVpQpLrVyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    # reduction="mean" averages the hidden states over residues,
    # yielding one pooled vector per input sequence
    embeddings = model.get_hidden_states(inputs, reduction="mean")
    print(embeddings[0].shape)
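
If you also want the unpooled, per-residue embeddings, you can call get_hidden_states without the reduction argument. A minimal sketch, continuing from the code above and assuming the default (no reduction) returns one tensor of shape [seq_len, hidden_dim] per input sequence with special tokens removed:

with torch.no_grad():
    # Assumption: omitting reduction keeps one hidden-state vector per residue
    per_residue = model.get_hidden_states(inputs)
    print(per_residue[0].shape)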

Hope this resolves your problem :)