agemagician / ProtTrans

ProtTrans provides state-of-the-art pretrained language models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using Transformer models.
Academic Free License v3.0

IndexError in generate protein notebook #152

Open ekiefl opened 5 months ago

ekiefl commented 5 months ago

Hello,

Thanks for the great tool. I'm excited to use ProtTrans for generating protein sequences, but I'm getting an index error in the example notebook (https://github.com/agemagician/ProtTrans/blob/master/Generate/ProtXLNet.ipynb).

The error occurs when running cell 12:

output_ids = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

Here's the full traceback:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[24], line 1
----> 1 output_ids = model.generate(
      2         input_ids=input_ids,
      3         max_length=max_length,
      4         temperature=temperature,
      5         top_k=k,
      6         top_p=p,
      7         repetition_penalty=repetition_penalty,
      8         do_sample=True,
      9         num_return_sequences=num_return_sequences,
     10     )

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/generation/utils.py:1758, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1750     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1751         input_ids=input_ids,
   1752         expand_size=generation_config.num_return_sequences,
   1753         is_encoder_decoder=self.config.is_encoder_decoder,
   1754         **model_kwargs,
   1755     )
   1757     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1758     result = self._sample(
   1759         input_ids,
   1760         logits_processor=prepared_logits_processor,
   1761         logits_warper=prepared_logits_warper,
   1762         stopping_criteria=prepared_stopping_criteria,
   1763         generation_config=generation_config,
   1764         synced_gpus=synced_gpus,
   1765         streamer=streamer,
   1766         **model_kwargs,
   1767     )
   1769 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   1770     # 11. prepare logits warper
   1771     prepared_logits_warper = (
   1772         self._get_logits_warper(generation_config) if generation_config.do_sample else None
   1773     )

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/generation/utils.py:2397, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2394 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2396 # forward pass to get next token
-> 2397 outputs = self(
   2398     **model_inputs,
   2399     return_dict=True,
   2400     output_attentions=output_attentions,
   2401     output_hidden_states=output_hidden_states,
   2402 )
   2404 if synced_gpus and this_peer_finished:
   2405     continue  # don't waste resources running the code we don't need

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/models/xlnet/modeling_xlnet.py:1440, in XLNetLMHeadModel.forward(self, input_ids, attention_mask, mems, perm_mask, target_mapping, token_type_ids, input_mask, head_mask, inputs_embeds, labels, use_mems, output_attentions, output_hidden_states, return_dict, **kwargs)
   1370 r"""
   1371 labels (`torch.LongTensor` of shape `(batch_size, num_predict)`, *optional*):
   1372     Labels for masked language modeling. `num_predict` corresponds to `target_mapping.shape[1]`. If
   (...)
   1436 ... )  # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
   1437 ```"""
   1438 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 1440 transformer_outputs = self.transformer(
   1441     input_ids,
   1442     attention_mask=attention_mask,
   1443     mems=mems,
   1444     perm_mask=perm_mask,
   1445     target_mapping=target_mapping,
   1446     token_type_ids=token_type_ids,
   1447     input_mask=input_mask,
   1448     head_mask=head_mask,
   1449     inputs_embeds=inputs_embeds,
   1450     use_mems=use_mems,
   1451     output_attentions=output_attentions,
   1452     output_hidden_states=output_hidden_states,
   1453     return_dict=return_dict,
   1454     **kwargs,
   1455 )
   1457 logits = self.lm_loss(transformer_outputs[0])
   1459 loss = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/models/xlnet/modeling_xlnet.py:1170, in XLNetModel.forward(self, input_ids, attention_mask, mems, perm_mask, target_mapping, token_type_ids, input_mask, head_mask, inputs_embeds, use_mems, output_attentions, output_hidden_states, return_dict, **kwargs)
   1168     word_emb_k = inputs_embeds
   1169 else:
-> 1170     word_emb_k = self.word_embedding(input_ids)
   1171 output_h = self.dropout(word_emb_k)
   1172 if target_mapping is not None:

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
    162 def forward(self, input: Tensor) -> Tensor:
--> 163     return F.embedding(
    164         input, self.weight, self.padding_idx, self.max_norm,
    165         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2258     # Note [embedding_renorm set_grad_enabled]
   2259     # XXX: equivalent to
   2260     # with torch.no_grad():
   2261     #   torch.embedding_renorm_
   2262     # remove once script supports set_grad_enabled
   2263     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self

This has occurred on two different sets of hardware.

Thanks for taking a look.

Evan

mheinzinger commented 5 months ago

Hi, thanks for your interest in our models! :) I have to admit that I have not used XLNet in the last few years. Is there a specific reason why you would not use ProtT5 (https://huggingface.co/Rostlab/prot_t5_xl_uniref50)? It is clearly our best-performing model so far, and I would highly recommend using it instead of any of the other models we trained (unless there is a good reason to specifically look into XLNet).

ekiefl commented 5 months ago

Hi @mheinzinger thanks for your help.

I'm interested in generating protein sequences. Perhaps you could help me navigate my options.

I started with the provided notebook only because that's where users tend to go when they are trying to get something up and running. I didn't use ProtT5 because it is not an example in Generate/.

Since you highly recommend ProtT5, I propose we remove the XLNet generate notebook and replace it with one that both works and uses a recommended model. If you could help me modify the notebook, I'd be happy to make a PR.

Does this sound like a plan to you? If so, here's what I've done so far:

 import torch
-from transformers import XLNetLMHeadModel, XLNetTokenizer,pipeline
+from transformers import T5Tokenizer, pipeline, T5ForConditionalGeneration
 import re
 import os
 import requests
 from tqdm.auto import tqdm

-tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)
+tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)

-model = XLNetLMHeadModel.from_pretrained("Rostlab/prot_xlnet")
+# Note: passing do_lower_case=False here raises TypeError: T5ForConditionalGeneration.__init__() got an unexpected keyword argument 'do_lower_case'
+model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_uniref50")

 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

 model = model.to(device)
 model = model.eval()

 sequences_Example = "A E T C Z A O"

 sequences_Example = re.sub(r"[UZOB]", "<unk>", sequences_Example)

 ids = tokenizer.encode(sequences_Example, add_special_tokens=False)

 input_ids = torch.tensor(ids).unsqueeze(0).to(device)

 max_length = 100
 temperature = 1.0
 k = 0
 p = 0.9
 repetition_penalty = 1.0
 num_return_sequences = 3

 output_ids = model.generate(
         input_ids=input_ids,
         max_length=max_length,
         temperature=temperature,
         top_k=k,
         top_p=p,
         repetition_penalty=repetition_penalty,
         do_sample=True,
         num_return_sequences=num_return_sequences,
     )

 output_sequences = [" ".join(" ".join(tokenizer.decode(output_id)).split()) for output_id in output_ids]

 print('Generated Sequences\n')
 for output_sequence in output_sequences:
   print(output_sequence)

+# Output:
+# < p a d > A E T C P A < / s >
+# < p a d > A E T C P A < / s >
+# < p a d > A E T C R A < / s >

Since the T5 model supports conditional generation, I think it would be nice to see a few examples of how that works.

mheinzinger commented 5 months ago

You are perfectly right; unfortunately, I have not found the time yet to update the examples. If you send a PR, I will happily accept it. The only thing to consider beforehand: T5 (and therefore ProtT5) is a bit different from the usual encoders (masked language modeling, BERT-style) and the usual decoders (autoregressive language modeling, GPT-style), as it combines both (encoder + decoder in one model). This is also reflected in the pre-training objective, which is a mix of both: span denoising, where certain spans of tokens in the input are masked and fed to the encoder, and the decoder is asked to regenerate the missing spans. Spans can be longer than one token, but for ProtT5 we stuck to single-token masking.

All that being said: you can probably simply create a loop which consecutively puts a mask token (or several) at the end of the sequence, asks ProtT5 to generate a token, and keeps doing this until you get a "proper" sequence (however you define proper).

What is probably a much better use case for ProtT5 in protein design is inpainting: you have a protein for which you would like to generate new candidate sequences, and you know some region relevant for binding/function; you can then mask this relevant region, give the masked sequence to ProtT5, and ask it to fill in the missing residues.
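A minimal, untested sketch of the inpainting approach described above, assuming that Rostlab/prot_t5_xl_uniref50 is loaded with T5ForConditionalGeneration as in the snippet earlier in the thread and that its tokenizer exposes a T5-style sentinel token such as <extra_id_0> (if ProtT5 used a different mask token during pre-training, it should be substituted here); the example sequence and the masked region are made up purely for illustration.

import re
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5ForConditionalGeneration.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device).eval()

# Toy sequence (space-separated residues); map rare amino acids to X.
sequence = "M K T A Y I A K Q R Q I S F V K S H F S R Q L E E R"
sequence = re.sub(r"[UZOB]", "X", sequence)

# Mask a region of interest (assumed here to span residues 10-12) with a sentinel token.
residues = sequence.split()
start, end = 10, 13
masked_sequence = " ".join(residues[:start] + ["<extra_id_0>"] + residues[end:])

input_ids = tokenizer(masked_sequence, return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=20,            # only the masked span needs to be generated
        do_sample=True,
        top_p=0.9,
        num_return_sequences=3,
    )

# The decoder output should contain the sentinel followed by the proposed residues for the
# masked span; decode without skipping special tokens to see where the span starts and ends.
for ids in output_ids:
    print(tokenizer.decode(ids, skip_special_tokens=False))

The loop-based approach described above could reuse the same pattern by repeatedly appending a sentinel token at the end of the sequence and keeping the newly sampled residue.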