limjiayi opened 11 months ago
Getting the same error with the T5 models. Minimal reproducible example:
from axtk.generation_utils import RegexLogitsProcessor, TokenHealingLogitsProcessor
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessorList
MODEL_ID = 't5-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
prompt = 'What is the capital of France? '
max_tokens = 20
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
processors = []
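# Token healing: back up over the trailing prompt token(s) so the first
# generated token can re-extend them across the tokenization boundary.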
healer = TokenHealingLogitsProcessor(input_ids[0], tokenizer)
healed_token_ids = healer.healed_token_ids
if len(healed_token_ids) > 0:
    input_ids = input_ids[:, :-len(healed_token_ids)]
    max_tokens += len(healed_token_ids)
    processors.append(healer)
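# Constrain decoding so the generated text matches the given regex.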
regex_processor = RegexLogitsProcessor(r'Paris|London|Berlin', prefix_length=len(prompt), stop_regex='', tokenizer=tokenizer)
processors.append(regex_processor)
processors = LogitsProcessorList(processors)
output = model.generate(input_ids, logits_processor=processors, max_new_tokens=max_tokens)
print(tokenizer.batch_decode(output, skip_special_tokens=True))

print(tokenizer.vocab_size, model.config.vocab_size)
# (32100, 32128)
The mismatch arises because the model's vocabulary is larger than the tokenizer's: T5 pads its output embedding matrix to 32128 (a multiple of 128, for hardware efficiency), while the tokenizer only defines 32100 tokens, so any logits mask sized from the tokenizer is too small for the model's score tensor.
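One possible workaround, sketched below, is to apply each processor only to the first tokenizer.vocab_size columns of the scores and leave the model's padding logits untouched. VocabSizeAdapter is a hypothetical helper (not part of axtk or transformers), and it assumes the wrapped processors follow the standard (input_ids, scores) LogitsProcessor call signature:

from transformers import LogitsProcessor, LogitsProcessorList

class VocabSizeAdapter(LogitsProcessor):
    # Hypothetical wrapper: runs `inner` on the tokenizer-sized slice of the
    # scores so masks built from the tokenizer never index past column 32100.
    def __init__(self, inner, tokenizer_vocab_size):
        self.inner = inner
        self.n = tokenizer_vocab_size

    def __call__(self, input_ids, scores):
        # Only the first `n` columns are real tokens; the remaining columns
        # are padding in the model's output embedding matrix.
        scores[:, :self.n] = self.inner(input_ids, scores[:, :self.n])
        return scores

processors = LogitsProcessorList(
    [VocabSizeAdapter(p, tokenizer.vocab_size) for p in processors]
)

Alternatively, model.resize_token_embeddings(len(tokenizer)) shrinks the output layer to 32100 so the two sizes match, at the cost of rebuilding the embedding matrix.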