marcovzla / axtk

MIT License
0 stars 1 forks source link

Bug when `model.config.vocab_size` differs from `tokenizer.vocab_size` #1

Open limjiayi opened 11 months ago

limjiayi commented 11 months ago

Minimal reproducible example

from axtk.generation_utils import RegexLogitsProcessor, TokenHealingLogitsProcessor
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

# Minimal reproduction: with facebook/opt-125m the logits passed to the
# processors have model.config.vocab_size (50272) columns, while the
# token-healing masks have tokenizer.vocab_size (50265) columns, so adding
# them inside generate() raises a shape-mismatch RuntimeError.
MODEL_ID = 'facebook/opt-125m'
# `return_tensors` is an argument of the tokenizer *call* (used below), not of
# from_pretrained(), so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

prompt = 'What is the capital of France? '
max_tokens = 20
input_ids = tokenizer(prompt, return_tensors='pt').input_ids

processors = []

# If the healer reports healed tokens, strip that many tokens from the end of
# the prompt, extend the generation budget by the same amount, and enable the
# healer as a logits processor.
healer = TokenHealingLogitsProcessor(input_ids[0], tokenizer)
healed_token_ids = healer.healed_token_ids
if len(healed_token_ids) > 0:
    input_ids = input_ids[:, :-len(healed_token_ids)]
    max_tokens += len(healed_token_ids)
    processors.append(healer)

# Regex-constrained processor for the pattern r'Paris|London|Berlin'.
regex_processor = RegexLogitsProcessor(r'Paris|London|Berlin', prefix_length=len(prompt), stop_regex='', tokenizer=tokenizer)
processors.append(regex_processor)

# Rebind `processors` itself (the original wrote to a misspelled `procesors`),
# so generate() really receives the LogitsProcessorList wrapper.
processors = LogitsProcessorList(processors)
output = model.generate(input_ids, logits_processor=processors, max_new_tokens=max_tokens)
# print() so the result is visible when run as a script, not only in a notebook.
print(tokenizer.batch_decode(output, skip_special_tokens=True))

Traceback

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 2
      1 procesors = LogitsProcessorList(processors)
----> 2 output = model.generate(input_ids, logits_processor=processors, max_new_tokens=max_tokens)
      3 tokenizer.batch_decode(output, skip_special_tokens=True)

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1607, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1590     return self.assisted_decoding(
   1591         input_ids,
   1592         assistant_model=assistant_model,
   (...)
   1603         **model_kwargs,
   1604     )
   1605 if generation_mode == GenerationMode.GREEDY_SEARCH:
   1606     # 11. run greedy search
-> 1607     return self.greedy_search(
   1608         input_ids,
   1609         logits_processor=logits_processor,
   1610         stopping_criteria=stopping_criteria,
   1611         pad_token_id=generation_config.pad_token_id,
   1612         eos_token_id=generation_config.eos_token_id,
   1613         output_scores=generation_config.output_scores,
   1614         return_dict_in_generate=generation_config.return_dict_in_generate,
   1615         synced_gpus=synced_gpus,
   1616         streamer=streamer,
   1617         **model_kwargs,
   1618     )
   1620 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
   1621     if not model_kwargs["use_cache"]:

File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2468, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2465 next_token_logits = outputs.logits[:, -1, :]
   2467 # pre-process distribution
-> 2468 next_tokens_scores = logits_processor(input_ids, next_token_logits)
   2470 # Store scores, attentions and hidden_states when required
   2471 if return_dict_in_generate:

File /opt/conda/lib/python3.10/site-packages/transformers/generation/logits_process.py:97, in LogitsProcessorList.__call__(self, input_ids, scores, **kwargs)
     95         scores = processor(input_ids, scores, **kwargs)
     96     else:
---> 97         scores = processor(input_ids, scores)
     98 return scores

File /opt/conda/lib/python3.10/site-packages/axtk/generation_utils/logits_processors/token_healing_logits_processor.py:90, in TokenHealingLogitsProcessor.__call__(self, input_ids, scores)
     87 scores = to_tensor(scores)
     89 # make only allowed tokens possible
---> 90 return scores + self.token_masks[self.num_extensions-1]

RuntimeError: The size of tensor a (50272) must match the size of tensor b (50265) at non-singleton dimension 1

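The mismatch occurs because the model's vocabulary size differs from the tokenizer's. The logits have `model.config.vocab_size` = 50272 columns, but the token-healing masks have `tokenizer.vocab_size` = 50265 columns, so the two tensors cannot be added.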

# Tokenizer vocabulary size vs. model.config.vocab_size (the logits width above).
(tokenizer.vocab_size, model.config.vocab_size)
# (50265, 50272)
limjiayi commented 11 months ago

I get the same error with T5 models (e.g. `t5-base`), where the tokenizer and model vocabulary sizes also differ:

from axtk.generation_utils import RegexLogitsProcessor, TokenHealingLogitsProcessor
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessorList

# Same reproduction with t5-base, whose tokenizer.vocab_size (32100) is also
# smaller than model.config.vocab_size (32128).
MODEL_ID = 't5-base'
# `return_tensors` belongs to the tokenizer call below, not to from_pretrained().
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

prompt = 'What is the capital of France? '
max_tokens = 20
input_ids = tokenizer(prompt, return_tensors='pt').input_ids

processors = []

# If the healer reports healed tokens, strip that many tokens from the end of
# the input, extend the generation budget by the same amount, and enable the
# healer as a logits processor.
healer = TokenHealingLogitsProcessor(input_ids[0], tokenizer)
healed_token_ids = healer.healed_token_ids
if len(healed_token_ids) > 0:
    input_ids = input_ids[:, :-len(healed_token_ids)]
    max_tokens += len(healed_token_ids)
    processors.append(healer)

# Regex-constrained processor for the pattern r'Paris|London|Berlin'.
regex_processor = RegexLogitsProcessor(r'Paris|London|Berlin', prefix_length=len(prompt), stop_regex='', tokenizer=tokenizer)
processors.append(regex_processor)

# Show the size mismatch *before* generate(), which raises on it; placed after
# generate() (as originally) this line was never reached.
print((tokenizer.vocab_size, model.config.vocab_size))
# (32100, 32128)

# Rebind `processors` itself (the original wrote to a misspelled `procesors`),
# so generate() really receives the LogitsProcessorList wrapper.
processors = LogitsProcessorList(processors)
output = model.generate(input_ids, logits_processor=processors, max_new_tokens=max_tokens)
print(tokenizer.batch_decode(output, skip_special_tokens=True))