ekiefl opened 5 months ago
Hi, thanks for your interest in our models! :) I have to admit that I have not used XLNet in the last few years. Is there a specific reason why you would not use ProtT5 (https://huggingface.co/Rostlab/prot_t5_xl_uniref50)? It is clearly our best-performing model so far, and I would highly recommend using it instead of any of the other models we trained (unless there is a good reason to look into XLNet specifically).
Hi @mheinzinger thanks for your help.
I'm interested in generating protein sequences. Perhaps you could help me navigate my options.
I started with the provided notebook only because that's where users tend to go when they're trying to get something up and running. I didn't use ProtT5 because there is no ProtT5 example in Generate/.
Since you highly recommend ProtT5, I propose we remove the XLNet generate script and replace it with one that both works and uses the recommended model. If you could help me modify the notebook, I'd be happy to make a PR.
Does this sound like a plan to you? If so, here's what I've done so far:
import torch
-from transformers import XLNetLMHeadModel, XLNetTokenizer,pipeline
+from transformers import T5Tokenizer, pipeline, T5ForConditionalGeneration
import re
import os
import requests
from tqdm.auto import tqdm
-tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)
+tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
-model = XLNetLMHeadModel.from_pretrained("Rostlab/prot_xlnet")
+# Note: do_lower_case cannot be passed to the model; doing so raises
+# TypeError: T5ForConditionalGeneration.__init__() got an unexpected keyword argument 'do_lower_case'
+model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_uniref50")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.eval()
sequences_Example = "A E T C Z A O"
# Map rare/ambiguous amino acids (U, Z, O, B) to the unknown token
# (the official ProtT5 examples map these to "X" instead)
sequences_Example = re.sub(r"[UZOB]", "<unk>", sequences_Example)
ids = tokenizer.encode(sequences_Example, add_special_tokens=False)
input_ids = torch.tensor(ids).unsqueeze(0).to(device)
max_length = 100           # cap on total generated length
temperature = 1.0          # sampling temperature (1.0 = unscaled logits)
k = 0                      # top-k filtering (0 disables it)
p = 0.9                    # nucleus (top-p) sampling threshold
repetition_penalty = 1.0   # 1.0 = no penalty
num_return_sequences = 3   # number of samples to draw
output_ids = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    temperature=temperature,
    top_k=k,
    top_p=p,
    repetition_penalty=repetition_penalty,
    do_sample=True,
    num_return_sequences=num_return_sequences,
)
# Decode each sample; joining the decoded string character by character and
# re-splitting normalizes whitespace, but it also spaces out special tokens
# like <pad> and </s> (see the output below)
output_sequences = [" ".join(" ".join(tokenizer.decode(output_id)).split()) for output_id in output_ids]
print('Generated Sequences\n')
for output_sequence in output_sequences:
    print(output_sequence)
+# Output:
+# < p a d > A E T C P A < / s >
+# < p a d > A E T C P A < / s >
+# < p a d > A E T C R A < / s >
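As an aside: the spaced-out "< p a d >" and "< / s >" fragments above come from the character-wise join in the decode step. If only the residues are wanted, decoding with skip_special_tokens (a standard transformers argument to tokenizer.decode) should be cleaner:

output_sequences = [
    tokenizer.decode(output_id, skip_special_tokens=True)
    for output_id in output_ids
]
# e.g. "A E T C P A" instead of "< p a d > A E T C P A < / s >"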
Since the T5 model supports conditional generation, I think it would be nice to see a few examples of how that works.
You are perfectly right; unfortunately, I have not yet found the time to update the examples. If you send a PR, I would happily accept it. One thing to consider beforehand: T5 (and thus ProtT5) is a bit different from the usual encoders (masked language modeling, BERT-style) and the usual decoders (causal language modeling, GPT-style), because it combines both: encoder and decoder in one model. This is also reflected in the pre-training objective, which is a mix of the two, namely span denoising: certain spans of tokens in the input are masked before it is fed to the encoder, and the decoder is asked to regenerate the missing spans. Spans can in general be longer than one token, but for ProtT5 we stuck with single-token masking.

All that being said: you can probably simply write a loop that consecutively appends a mask token (or several) to the end of the sequence, asks ProtT5 to generate a token, and repeats until you get a "proper" sequence (however you define proper). What is probably a much better use case for ProtT5 in protein design is inpainting: you have a protein for which you would like to generate new candidate sequences, and you know some region relevant for binding/function; you can mask this relevant region, give the masked sequence to ProtT5, and ask it to fill in the missing residues.
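To make the inpainting idea concrete, here is a rough sketch (not from the repository's notebooks). It assumes the ProtT5 tokenizer carries over T5's sentinel tokens (<extra_id_0>, <extra_id_1>, ...) for masking; the sequence and the masked region are made up for illustration:

import re
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_uniref50")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()

# Made-up protein and a made-up "functional" region to redesign
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
start, end = 10, 15

# ProtT5 expects space-separated residues; map rare amino acids to X
tokens = list(re.sub(r"[UZOB]", "X", sequence))
# Single-token masking: one sentinel per masked position
# (assumes the <extra_id_*> tokens exist in the ProtT5 vocab)
masked = tokens[:start] + [f"<extra_id_{i}>" for i in range(end - start)] + tokens[end:]

input_ids = tokenizer(" ".join(masked), return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=2 * (end - start) + 2, do_sample=True, top_p=0.9)

# The decoder emits "<extra_id_0> R <extra_id_1> L ..."; dropping the
# sentinels leaves the predicted fill for the masked region
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Appending the sentinels at the end of the sequence instead of in the middle would give the iterative-extension variant described above.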
Hello,
Thanks for the great tool. I'm excited to use ProtTrans for generating protein sequences, but I'm getting an index error in the example notebook (https://github.com/agemagician/ProtTrans/blob/master/Generate/ProtXLNet.ipynb).
The error occurs when running cell 12:
Here's the full traceback:
This has occurred on two different sets of hardware.
Thanks for taking a look.
Evan