Open junoriosity opened 1 year ago
I would like to use SciBERT for iterative token generation. Here is my code:
    import torch
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModelForCausalLM

    device = "cuda"
    tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
    model = AutoModelForCausalLM.from_pretrained('allenai/scibert_scivocab_uncased').to(device)

    input_sequence = "Hello, I'm a language model,"
    inputs = torch.as_tensor(tokenizer.encode(input_sequence)).unsqueeze(0).to(device)
    attention_mask = torch.as_tensor(tokenizer(input_sequence).attention_mask).unsqueeze(0).to(device)

    past_key_values = None
    count = 0
    complete_token = []

    with torch.no_grad():
        while count < 10:
            count += 1
            print("Iteration no.: " + str(count))
            if count > 1:
                inputs = input_token
            print(inputs.to(device))
            print(attention_mask)
            print(past_key_values[0][0].shape if past_key_values else None)

            model_out = model(input_ids=inputs.to(device), attention_mask=attention_mask, past_key_values=past_key_values)
            logits = model_out.logits[:, -1, :]
            past_key_values = model_out.past_key_values

            topk_values, topk_indices = torch.topk(logits, 5)
            log_probs = F.softmax(topk_values, dim=-1)
            inputs_in_topk = torch.multinomial(log_probs, num_samples=1, replacement=True)
            input_token = torch.gather(topk_indices, 1, inputs_in_topk)

            attention_mask = torch.concat((attention_mask, torch.tensor([[1]]).to(attention_mask.device)), dim=1)
            complete_token.append(input_token)
However, past_key_values is None on every iteration. I tried this approach with other models, and past_key_values is not None there. How can I make the iteration work here, so that each step has the knowledge of the previous iteration?
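One possible explanation, offered as a sketch rather than a confirmed diagnosis: in the transformers library, BERT-style models only return a key/value cache when the config has is_decoder=True. The SciBERT checkpoint is an encoder, so the default is presumably False. The snippet below does not load SciBERT. It builds a tiny, randomly initialised BertLMHeadModel (all sizes are made up for illustration) and compares the two settings; returns_cache is a hypothetical helper name.

```python
import torch
from transformers import BertConfig, BertLMHeadModel


def returns_cache(is_decoder: bool) -> bool:
    """Report whether a tiny BERT LM head model returns past_key_values."""
    # Tiny illustrative config; these sizes are assumptions, not SciBERT's.
    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        is_decoder=is_decoder,
    )
    model = BertLMHeadModel(config).eval()
    input_ids = torch.tensor([[1, 2, 3, 4]])
    with torch.no_grad():
        out = model(input_ids=input_ids, use_cache=True)
    return out.past_key_values is not None


encoder_mode = returns_cache(is_decoder=False)  # encoder-style config
decoder_mode = returns_cache(is_decoder=True)   # decoder-style config
print(encoder_mode, decoder_mode)
```

If this is indeed the cause, passing is_decoder=True as an extra keyword argument to AutoModelForCausalLM.from_pretrained may be enough to get a cache back from the pretrained checkpoint. Since the weights were trained with bidirectional attention, the quality of the generated tokens is a separate question.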