mircare closed this issue 2 years ago.
Absolutely the right place to ask/comment, don't worry. The error you see occurs only because a half-precision model is being run on CPU, which is currently not supported (yet?). If you run the same code on Google Colab with a GPU, you should not see any error. You could "fix" the error by casting the model to full precision (model = model.float()), or (recommended!) run it on a GPU.
We will adjust our repo accordingly to make this point clearer. Thanks for the heads-up!
Best, Michael
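A minimal sketch of the two options above (GPU in half precision, otherwise upcast with model.float() on CPU). `place_model` is a hypothetical helper name, and a small `nn.Linear` cast to half precision stands in for the ProtT5 encoder so the snippet runs without downloading the model:

```python
import torch
from torch import nn


def place_model(model: nn.Module, device: str) -> nn.Module:
    """Move `model` to `device`, upcasting to float32 when running on CPU.

    Hypothetical helper: keeps half precision on GPU and follows the
    model.float() workaround on CPU.
    """
    if device.startswith("cuda"):
        return model.to(device)
    # Half-precision inference on CPU is what triggers the error discussed here,
    # so cast all floating-point parameters and buffers to float32 first.
    return model.float().to(device)


device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Stand-in for the half-precision encoder (assumption: any fp16 module behaves the same way).
layer = place_model(nn.Linear(4, 2).half(), device)

# Match the input dtype to the (possibly upcast) parameters.
x = torch.randn(3, 4, device=device, dtype=next(layer.parameters()).dtype)
with torch.no_grad():
    y = layer(x)
print(y.dtype, tuple(y.shape))
```

With the real model you would call `place_model(model, device)` right after `T5EncoderModel.from_pretrained(...)`; the token IDs and attention mask are integer tensors, so only the model itself needs the cast. Keep in mind that the float32 copy needs roughly twice the memory of the half-precision weights.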
Thank you for the fast reply @mheinzinger. Indeed, moving the model and the inputs to the GPU solved it. Here is the updated code (there was a typo in line 21):
```python
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re

# Use the GPU if available; the half-precision model is not supported on CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16).to(device)

# Residues are space-separated; rare/ambiguous amino acids (U, Z, O, B) are mapped to X
sequences_Example = ["A E T C Z A O", "S K T Z P"]
seqs = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences_Example]

# Tokenize, adding the end-of-sequence token, and pad to the longest sequence
ids = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)

# Generate embeddings without tracking gradients
with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)

# Per-residue embeddings: one row per amino acid, dropping padding and the final </s> token
emb_0 = embedding_repr.last_hidden_state[0, :7]  # 7 residues in the first sequence
emb_1 = embedding_repr.last_hidden_state[1, :5]  # 5 residues in the second sequence
```
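Instead of hard-coding the slice lengths for emb_0 and emb_1, the residue count can be read off the attention mask. This is a sketch that assumes each amino acid maps to exactly one token and that the tokenizer appends a single </s> token; `residue_embeddings` is a hypothetical helper, and a random tensor stands in for `embedding_repr.last_hidden_state` (1024 is assumed as the hidden size). Mean pooling to a per-protein vector is shown as one common option, not something prescribed in this thread:

```python
import re

import torch


def residue_embeddings(last_hidden_state, attention_mask):
    """Split a padded batch of encoder outputs into per-protein residue embeddings.

    Assumes one token per residue plus one trailing </s>, which is dropped.
    """
    lengths = attention_mask.sum(dim=1) - 1  # real tokens minus the trailing </s>
    return [last_hidden_state[i, :n] for i, n in enumerate(lengths.tolist())]


sequences = ["A E T C Z A O", "S K T Z P"]
seqs = [re.sub(r"[UZOB]", "X", s) for s in sequences]

# Hypothetical stand-ins for the tokenizer/model outputs: one token per residue plus </s>.
n_tokens = [len(s.split()) + 1 for s in seqs]
max_len = max(n_tokens)
attention_mask = torch.tensor([[1] * n + [0] * (max_len - n) for n in n_tokens])
last_hidden_state = torch.randn(len(seqs), max_len, 1024)

per_residue = residue_embeddings(last_hidden_state, attention_mask)
per_protein = [emb.mean(dim=0) for emb in per_residue]  # average over residues
print([tuple(e.shape) for e in per_residue])  # [(7, 1024), (5, 1024)]
print([tuple(p.shape) for p in per_protein])  # [(1024,), (1024,)]
```

With the real model, pass `embedding_repr.last_hidden_state` and the `attention_mask` tensor built from the tokenizer output.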
Let me add that I am a big fan of this project. Keep it up!
I'd like to let you know that the instructions at https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc don't seem to work. I get the same error in two different Python environments, even when installing only PyTorch and the transformers library.
This is the code I run:
This is the error I get:
I am unsure whether this is the most appropriate place; please let me know otherwise. Please also note that I assign the output of re.sub to 'seqs' instead of back to 'sequences_Example' (as done in the example on Hugging Face).
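Whichever name the cleaned list is bound to, what matters is that the list passed to the tokenizer has rare/ambiguous amino acids replaced by X and residues separated by spaces, as in sequences_Example. A small sketch of that preprocessing; `prepare_sequences` is a hypothetical helper, and stripping existing spaces first (so both spaced and unspaced input work) is an assumption of this sketch:

```python
import re


def prepare_sequences(sequences):
    """Map U, Z, O, B to X and space-separate residues (hypothetical helper)."""
    cleaned = []
    for seq in sequences:
        compact = seq.replace(" ", "")            # accept spaced or unspaced input
        compact = re.sub(r"[UZOB]", "X", compact)  # rare/ambiguous residues -> X
        cleaned.append(" ".join(compact))          # one space between residues
    return cleaned


print(prepare_sequences(["AETCZAO", "S K T Z P"]))  # ['A E T C X A X', 'S K T X P']
```

The result can be handed directly to tokenizer.batch_encode_plus in place of seqs.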