westlake-repl / SaProt

[ICLR'24 spotlight] Saprot: Protein Language Model with Structural Alphabet
MIT License

The meaning of output size #20

Closed Yangqy-16 closed 3 months ago

Yangqy-16 commented 3 months ago

Thank you for your great work and for sharing the repository! I wonder about the meaning of the output size torch.Size([1, 11, 446]) in the example in your README. I suppose '446' is the size of the vocabulary, but why is it different from the '441' given in your article? Also, do you provide code/scripts to load batches of sequences into your model, as ESM does? Thank you very much!

LTEnjoy commented 3 months ago

Hi, thank you for your interest in our work!

Your understanding of the output size is right! The size of our vocabulary is 446 rather than 441 because we additionally add some special tokens, such as <cls> and <mask>, to the vocabulary. You can check them in vocab.txt provided in our checkpoint directory.
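
If you want to verify this yourself, something like the following (a quick sketch using the same placeholder checkpoint path as in the example below) should print the vocabulary size and the special tokens:

from transformers import EsmTokenizer

model_path = "/your/path/to/SaProt_650M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_path)

# 446 entries in total; the ones beyond the 441 structure-aware tokens are specials like <cls> and <mask>
print(len(tokenizer))
print(tokenizer.all_special_tokens)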

As our model is fully compatible with HuggingFace, you can load a batch of sequences like this:

from transformers import EsmTokenizer, EsmForMaskedLM

model_path = "/your/path/to/SaProt_650M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_path)
model = EsmForMaskedLM.from_pretrained(model_path)

#################### Example ####################
device = "cuda"
model.to(device)

seqs = ["MdEvVpQpLrVyQdYaKv", "MdEvVpQpLrVyQdYaKv"]
tokens = tokenizer.tokenize(seq)
print(tokens)

# Batch-encode with padding so sequences of different lengths can be stacked into one tensor
inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Forward pass; logits have shape (batch_size, sequence_length, vocab_size)
outputs = model(**inputs)
print(outputs.logits.shape)

"""
['Md', 'Ev', 'Vp', 'Qp', 'Lr', 'Vy', 'Qd', 'Ya', 'Kv']
torch.Size([2, 11, 446])
"""

Hope this resolves your problem! :)

Yangqy-16 commented 3 months ago

Now I totally understand. Thank you very much!