evolutionaryscale / esm

Other
1.13k stars 123 forks source link

Limit AA alphabet #64

Closed neuwirtter closed 3 weeks ago

neuwirtter commented 1 month ago

Dear team,

I tried to restrict the amino-acid (AA) alphabet of the generated protein with the code below, but I still get proteins containing the amino acids I want to exclude. Could you please help me find where the problem is?

Thank you very much for your assistance.

from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
from esm.tokenization import get_model_tokenizers
from esm.utils.constants.models import ESM3_OPEN_SMALL
import argparse
# Round-trip design script: read a starting sequence from a FASTA file, mask
# every 'A', generate a structure with ESM3, then repeatedly inverse-fold that
# structure into new sequences while asking the sampler to avoid 'A'.

# Parse arguments. All three are needed further down, so they are required
# here to fail fast instead of after the model has been downloaded and loaded.
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True, help="Starting sequence")
parser.add_argument('--output', required=True, help="Output file")
parser.add_argument('--num', required=True, type=int, help="How many sequences to generate")
#parser.add_argument('--temp', help="Temperature setting")
args = parser.parse_args()

from Bio import SeqIO

# Use the first record of the input FASTA as the prompt. The index keeps the
# file open, so it is closed explicitly once the sequence has been copied out.
query = SeqIO.index(args.input, "fasta")
try:
    record_ids = list(query.keys())
    if not record_ids:
        parser.error(f"no FASTA records found in {args.input}")
    my_prompt = str(query[record_ids[0]].seq)
finally:
    query.close()

# Replace every 'A' in the prompt with the mask character '_'.
masked_prompt = my_prompt.replace('A', '_')

# This will prompt you to get an API key from huggingface hub, make one with
# "Read" or "Write" permission and copy it back here.
login()

# This will download the model weights and instantiate the model on your machine.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cpu")  # or "cuda"
# Get the tokenizers
tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
sequence_tokenizer = tokenizers.sequence

protein = ESMProtein(sequence=masked_prompt)

# Get the token ID for 'A'
token_id_for_A = sequence_tokenizer.convert_tokens_to_ids('A')
print(f"Token ID for 'A': {token_id_for_A}")

# Generate the structure track for the masked prompt.
# NOTE(review): invalid_ids holds an id from the *sequence* tokenizer but is
# passed to a structure-track generation here - confirm this is meaningful
# for the structure vocabulary.
generation_config_structure = GenerationConfig(
    track="structure",
    num_steps=8,
    invalid_ids=[token_id_for_A],
)
protein = model.generate(protein, generation_config_structure)

# Print the sequence carried by the protein after structure generation.
print("Generated sequence (structure track) without 'A':", protein.sequence)

# NOTE(review): this second structure generation repeats the call above with
# an identical configuration; it looks redundant - confirm before removing.
generation_config_structure = GenerationConfig(
    track="structure",
    num_steps=8,
    invalid_ids=[token_id_for_A],
)
protein = model.generate(protein, generation_config_structure)

# Then we can do a round trip design by inverse folding the sequence and
# recomputing the structure. The context manager guarantees the FASTA output
# is flushed and closed even if a generation step raises.
with open(args.output, "w") as outh:
    for i in range(args.num):
        # Drop the current sequence so the sequence track is generated again
        # from the structure on every iteration.
        protein.sequence = None
        generation_config_sequence = GenerationConfig(
            track="sequence",
            num_steps=8,
            invalid_ids=[token_id_for_A],
        )
        protein = model.generate(protein, generation_config_sequence)
        # Print the final protein sequence without 'A'
        print("Final generated sequence (sequence track) without 'A':", protein.sequence)
        outh.write(f">Protein_{i}\n{protein.sequence}\n")
#protein.coordinates = None
#protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
#protein.to_pdb("./round_tripped.pdb")
ebetica commented 3 weeks ago

I think this should be fixed in #95