westlake-repl / SaProt

[ICLR'24 spotlight] Saprot: Protein Language Model with Structural Alphabet
MIT License

per residue representations #28

Closed: FlorianWieser1 closed this issue 5 months ago

FlorianWieser1 commented 5 months ago

Dear authors,

Thank you for sharing this great work with us! I wonder if it's possible to extract per-residue representations, as with ESM-2?

import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

Thank you in advance!

LTEnjoy commented 5 months ago

Hello, thank you for your interest in our work!

We provide an easy way to extract per-residue representations. See the example below:

import torch

from model.esm.base import EsmBaseModel
from transformers import EsmTokenizer

config = {
    "task": "base",
    "config_path": "/sujin/Models/SaProt/SaProt_650M_AF2",  # local path to the SaProt_650M_AF2 weights
    "load_pretrained": True,
}

model = EsmBaseModel(**config)
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])

device = "cuda"
model.to(device)

seq = "MdEvVpQpLrVyQdYaKv"  # structure-aware sequence: amino acid + lowercase Foldseek 3Di token per residue
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    embeddings = model.get_hidden_states(inputs)
    print(embeddings[0].shape)  # per-residue embeddings of the first sequence
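
If you want a single embedding for the whole protein rather than per-residue vectors, you can mean-pool over the residue dimension. A quick sketch (assuming embeddings[0] is a [seq_len, hidden_dim] tensor for the first sequence, as the printed shape above suggests):

# mean-pool the per-residue embeddings into one vector per protein
per_residue = embeddings[0]            # [seq_len, hidden_dim]
per_protein = per_residue.mean(dim=0)  # [hidden_dim]
print(per_protein.shape)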

Hope this resolves your problem :)

FlorianWieser1 commented 5 months ago

Thank you for the quick reply, works for me! :)