Closed yxwvutt closed 3 months ago
Iterative function decoding is still in the works. However, you can get function annotations out of a given protein using forward_and_sample:
import torch

from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain

# Initialize the client
client = ESM3.from_pretrained(ESM3_OPEN_SMALL, device=torch.device("cuda"))

# Load the protein
protein = ProteinChain.from_rcsb("1utn")
protein = ESMProtein.from_protein_chain(protein)

# Predict function
protein_tensor = client.encode(protein)
inference_output = client.forward_and_sample(
    protein_tensor,
    SamplingConfig(
        sequence=SamplingTrackConfig(),
        structure=SamplingTrackConfig(),
        secondary_structure=SamplingTrackConfig(),
        sasa=SamplingTrackConfig(),
        function=SamplingTrackConfig(only_sample_masked_tokens=False),
    ),
)
protein_tensor_with_function = inference_output.protein_tensor
protein_with_function = client.decode(protein_tensor_with_function)
print(protein_with_function.function_annotations)
Thanks for sharing this script. I tried running it, but I get an error at the decoding step: protein_with_function = client.decode(protein_tensor_with_function)
Any ideas on how to fix this? Thanks
File /opt/conda/lib/python3.11/site-packages/esm/utils/function/encode_decode.py:168, in decode_residue_annotation_tokens(residue_annotations_token_ids, residue_annotations_tokenizer, annotation_min_length, annotation_gap_merge_max)
    166 for depth in range(0, C.MAX_RESIDUE_ANNOTATIONS):
    167     token_ids = residue_annotations_token_ids[:, depth]
--> 168     for loc, vocab_index in torch.nonzero(token_ids).cpu().numpy():
    169         label = residue_annotations_tokenizer.vocabulary[vocab_index]
    170         if label not in [*residue_annotations_tokenizer.special_tokens, "
ValueError: not enough values to unpack (expected 2, got 1)
Can you install from source?
This issue was recently fixed here: https://github.com/evolutionaryscale/esm/pull/36, but the fix is not yet on PyPI.
Installing from source fixed it! Thanks :)
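For anyone else landing on the traceback above: the message is what Python raises when a row holding a single value is unpacked into two names. torch.nonzero returns one index column per tensor dimension, so a 1-D token_ids yields one-element rows and the "for loc, vocab_index in ..." loop fails. The sketch below reproduces that behaviour with a pure-Python stand-in so it runs without torch. Note that nonzero_rows and unpack_pairs are hypothetical helpers written for this illustration, not part of esm or torch, and this is not the actual change made in the PR.

```python
def nonzero_rows(values):
    """Hypothetical stand-in for torch.nonzero(t).cpu().numpy().

    Returns one index row per nonzero entry, with one column per
    dimension of the input (supports 1-D and 2-D nested lists only).
    """
    if values and isinstance(values[0], list):
        return [
            [i, j]
            for i, row in enumerate(values)
            for j, value in enumerate(row)
            if value != 0
        ]
    return [[i] for i, value in enumerate(values) if value != 0]


def unpack_pairs(rows):
    """Unpack each row into two names, as the loop in the traceback does."""
    return [(loc, vocab_index) for loc, vocab_index in rows]


# 2-D input: every row has two indices, so unpacking works.
pairs_2d = unpack_pairs(nonzero_rows([[0, 7], [3, 0]]))

# 1-D input: every row has a single index, so unpacking fails.
try:
    unpack_pairs(nonzero_rows([0, 7, 3]))
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(pairs_2d)
print(error_message)
```

If you cannot install from source, checking the shape of the tensor passed to torch.nonzero at that line should at least confirm whether you are hitting this same case.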
I tried the code above, but an error is raised at:
protein_with_function = client.decode(protein_tensor_with_function)
state_dict = torch.load(
ValueError: SASA does not start with 0 corresponding to BOS token
I don't know why this error occurred, since I simply copied and ran the code.
Hello there, amazing work! I was reading through the ESM3 preprint and saw that there was some work done to measure the ESM3-open function prediction performance, so I wanted to test the function prediction out.
I was going through the generate.ipynb notebook and noticed that there is no tutorial on predicting a sequence's function. When I change the GenerationConfig track to "function", I get the error "Sampling only masked tokens is undefined for function tokens."
After looking through the esm classes, I found a SamplingTrackConfig class that by default has only_sample_masked_tokens: bool = True. However, it does not seem possible to pass this class to model.generate(), which appears to accept only ESMProtein and GenerationConfig as inputs.
Is there a way to allow for function prediction of an input protein sequence?