Closed FlorianWieser1 closed 5 months ago
Hello, thank you for your interest in our work!
We provide an easy way to extract the per-residue representations. See the example below:
```python
import torch

from model.esm.base import EsmBaseModel
from transformers import EsmTokenizer

config = {
    "task": "base",
    "config_path": "/sujin/Models/SaProt/SaProt_650M_AF2",  # path to the pre-trained weights
    "load_pretrained": True,
}

model = EsmBaseModel(**config)
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])

device = "cuda"
model.to(device)

# Structure-aware sequence: each residue letter is paired with a lowercase structure token
seq = "MdEvVpQpLrVyQdYaKv"

tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    embeddings = model.get_hidden_states(inputs)
    print(embeddings[0].shape)
```
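If the returned hidden states still include the special tokens (as with ESM-2-style tokenizers, which prepend `<cls>` and append `<eos>`), you may want to trim them off and, optionally, mean-pool to get a single per-sequence embedding. Below is a minimal sketch using a random tensor in place of real model output; the shapes (9 residue tokens, hidden size 1280) are assumptions for illustration only:

```python
import torch

# Stand-in for a model's hidden states: batch of 1, 11 tokens
# (<cls> + 9 residue tokens + <eos>), hidden size 1280.
# These shapes are placeholders, not SaProt's actual output.
hidden = torch.randn(1, 11, 1280)

# Drop the first (<cls>) and last (<eos>) positions to keep residues only.
per_residue = hidden[:, 1:-1, :]        # shape: (1, 9, 1280)

# Mean-pool over the residue dimension for one per-sequence embedding.
per_sequence = per_residue.mean(dim=1)  # shape: (1, 1280)

print(per_residue.shape, per_sequence.shape)
```

Whether trimming is needed depends on what `get_hidden_states` returns, so check the output length against your token count first.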
Hope this resolves your problem :)
Thank you for the quick reply, works for me! :)
Dear authors,
Thank you for sharing this great work with us! I wonder if it's possible to extract the per-residue representations, as with ESM-2?
Thank you in advance!