evolutionaryscale / esm


Sequence Function Prediction #24

Closed yxwvutt closed 3 months ago

yxwvutt commented 3 months ago

Hello there, amazing work! While reading the ESM3 preprint, I saw that ESM3-open's function prediction performance was evaluated, so I wanted to try function prediction myself.

I went through the generate.ipynb notebook and noticed that it has no tutorial on predicting a sequence's function. When I set the GenerationConfig track to "function", I get the error "Sampling only masked tokens is undefined for function tokens."


After looking through the esm classes, I found a SamplingTrackConfig class whose only_sample_masked_tokens: bool field defaults to True. There does not seem to be a way to pass this class to model.generate(), since model.generate() appears to accept only an ESMProtein and a GenerationConfig:

def generate(input: ProteinType, config: GenerationConfig) -> ProteinType

Is there a way to allow for function prediction of an input protein sequence?
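To illustrate what a flag like only_sample_masked_tokens usually controls, here is a toy sketch of one masked-sampling step over a single track. Every name in it (sample_track, MASK, always_seven) is hypothetical and none of it comes from the esm package; it only shows the general idea that, with the flag set to True, positions that already hold a token are left untouched and only masked positions are filled in.

```python
from typing import Callable, List

MASK = -1  # hypothetical mask token id used only in this toy example


def sample_track(
    tokens: List[int],
    sampler: Callable[[int], int],
    only_sample_masked_tokens: bool = True,
) -> List[int]:
    """One toy sampling step over a single track (not ESM3's implementation).

    only_sample_masked_tokens=True: replace only MASK positions.
    only_sample_masked_tokens=False: resample every position.
    """
    sampled = []
    for pos, tok in enumerate(tokens):
        if only_sample_masked_tokens and tok != MASK:
            sampled.append(tok)  # keep the token supplied by the caller
        else:
            sampled.append(sampler(pos))  # draw a new token at this position
    return sampled


def always_seven(pos: int) -> int:
    """Deterministic stand-in for a model's sampler: predicts token 7 everywhere."""
    return 7


print(sample_track([3, MASK, 5], always_seven, only_sample_masked_tokens=True))   # [3, 7, 5]
print(sample_track([3, MASK, 5], always_seven, only_sample_masked_tokens=False))  # [7, 7, 7]
```

Under this toy model, getting a prediction for every position of a track means turning the flag off, which is consistent with the reply below passing only_sample_masked_tokens=False for the function track.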

santiag0m commented 3 months ago

Iterative function decoding is still in the works. However, you can get function annotations out of a given protein using forward_and_sample:

import torch

from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain

# Initialize the client
client = ESM3.from_pretrained(ESM3_OPEN_SMALL, device=torch.device("cuda"))

# Load the protein
protein = ProteinChain.from_rcsb("1utn")
protein = ESMProtein.from_protein_chain(protein)

# Predict function
protein_tensor = client.encode(protein)
inference_output = client.forward_and_sample(
    protein_tensor,
    SamplingConfig(
        sequence=SamplingTrackConfig(),
        structure=SamplingTrackConfig(),
        secondary_structure=SamplingTrackConfig(),
        sasa=SamplingTrackConfig(),
        function=SamplingTrackConfig(only_sample_masked_tokens=False),
    ),
)
protein_tensor_with_function = inference_output.protein_tensor
protein_with_function = client.decode(protein_tensor_with_function)
print(protein_with_function.function_annotations)
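The print call above dumps the raw annotation list. For a more readable summary, a small helper can group the predicted residue spans by label. This is a sketch under an assumption: that each entry of function_annotations exposes label, start and end attributes, mirroring esm's FunctionAnnotation type (verify the attribute names against your installed version). The Annotation dataclass and group_by_label are hypothetical names introduced here, and the example input is made up, not real ESM3 output.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple


@dataclass
class Annotation:
    """Hypothetical stand-in for one entry of function_annotations.

    Assumes a text label plus start/end residue positions.
    """
    label: str
    start: int
    end: int


def group_by_label(annotations: Iterable[Annotation]) -> Dict[str, List[Tuple[int, int]]]:
    """Collect the (start, end) spans predicted for each label, sorted by start."""
    spans: Dict[str, List[Tuple[int, int]]] = defaultdict(list)
    for ann in annotations:
        spans[ann.label].append((ann.start, ann.end))
    return {label: sorted(ranges) for label, ranges in spans.items()}


# Made-up example input, used only to show the output shape.
example = [
    Annotation("label_A", 40, 60),
    Annotation("label_B", 5, 20),
    Annotation("label_A", 1, 30),
]
for label, ranges in sorted(group_by_label(example).items()):
    print(label, ranges)
# label_A [(1, 30), (40, 60)]
# label_B [(5, 20)]
```

If the attribute names match, the same helper could be called as group_by_label(protein_with_function.function_annotations or []), where "or []" guards against the field being None.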

HaiLunP commented 3 months ago

Thanks for sharing this script. I tried running it, but I get an error at the decoding step: protein_with_function = client.decode(protein_tensor_with_function)

Any ideas on how to fix this? Thanks

File /opt/conda/lib/python3.11/site-packages/esm/utils/function/encode_decode.py:168, in decode_residue_annotation_tokens(residue_annotations_token_ids, residue_annotations_tokenizer, annotation_min_length, annotation_gap_merge_max)
    166 for depth in range(0, C.MAX_RESIDUE_ANNOTATIONS):
    167     token_ids = residue_annotations_token_ids[:, depth]
--> 168     for loc, vocab_index in torch.nonzero(token_ids).cpu().numpy():
    169         label = residue_annotations_tokenizer.vocabulary[vocab_index]
    170         if label not in [*residue_annotations_tokenizer.special_tokens, ""]:

ValueError: not enough values to unpack (expected 2, got 1)
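The failing line can be reproduced without ESM3. If residue_annotations_token_ids is a 2-D (length × depth) tensor, then token_ids = residue_annotations_token_ids[:, depth] is 1-D, and torch.nonzero on a 1-D tensor returns indices of shape (N, 1): one value per row, so unpacking each row into loc, vocab_index fails. The sketch below uses NumPy's argwhere as a stand-in for torch.nonzero (both return an (N, ndim) index array); token_ids, buggy_pairs and fixed_pairs are hypothetical names. It illustrates the failure mode only and is not necessarily the change made in PR #36.

```python
import numpy as np

# Toy 1-D token row for one annotation depth: nonzero entries are vocab ids.
token_ids = np.array([0, 12, 0, 0, 7])

# For a 1-D input, argwhere returns shape (N, 1): a single index per row.
positions = np.argwhere(token_ids)
print(positions.shape)  # (2, 1)


def buggy_pairs(ids: np.ndarray) -> list:
    # Mirrors the failing line: each row has length 1, so unpacking two names fails.
    return [(loc, vocab_index) for loc, vocab_index in np.argwhere(ids)]


def fixed_pairs(ids: np.ndarray) -> list:
    # Unpack the single index per row, then look up the vocab id stored there.
    return [(int(loc), int(ids[loc])) for (loc,) in np.argwhere(ids)]


try:
    buggy_pairs(token_ids)
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 1)

print(fixed_pairs(token_ids))  # [(1, 12), (4, 7)]
```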

santiag0m commented 3 months ago

Can you install from source?

This issue was recently fixed in https://github.com/evolutionaryscale/esm/pull/36, but the fix is not yet released on PyPI.

HaiLunP commented 3 months ago

Installing from source fixed it! Thanks :)

HongxiangXu commented 1 month ago

I ran the code above unchanged, but an error is raised at protein_with_function = client.decode(protein_tensor_with_function):

state_dict = torch.load(
ValueError: SASA does not start with 0 corresponding to BOS token

I don't know why this error occurred, since I merely copied and ran the code.