noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License
994 stars · 45 forks

Question about Source Code #117

Open Acatsama0871 opened 4 days ago

Acatsama0871 commented 4 days ago

Hello,

I would first like to thank you for open-sourcing such a well-designed, high-quality code base.

I am reading the source code, and I have a question about this part (integrations.transformers.py):

def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str, bool]]:
    token_0 = tokenizer.encode("0")[-1]
    regular_tokens = []
    for token_idx in range(len(tokenizer)):
        if token_idx in tokenizer.all_special_ids:
            continue
        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
        decoded_regular = tokenizer.decode([token_idx])
        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
    return regular_tokens

Why is the same token ID decoded twice here? And what does "word start" mean in this context? Thanks!

noamgat commented 3 days ago

We decode the token ID twice, once on its own and once after the representation of "0", to determine whether it is a word-start token. A word-start token yields an extra character (the leading space) when decoded after "0" compared to when it is decoded alone, and that is what the length comparison checks. Some tokenizers expose this information directly, but this method is tokenizer-agnostic.
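To make the trick concrete, here is a minimal sketch with a hypothetical SentencePiece-style toy tokenizer (the `ToyTokenizer` class and its vocabulary are invented for illustration, not part of lm-format-enforcer). Such tokenizers mark word-start tokens internally (e.g. with "▁") and strip the resulting leading space only at the start of a sequence, so prepending "0" preserves the space and makes the decoded text one character longer:

```python
from typing import List


class ToyTokenizer:
    """Hypothetical SentencePiece-style tokenizer: '▁' marks word starts."""
    vocab = {0: "0", 1: "▁hello", 2: "lo"}

    def decode(self, ids: List[int]) -> str:
        text = "".join(self.vocab[i] for i in ids).replace("▁", " ")
        # Like SentencePiece, drop the leading space at sequence start.
        return text[1:] if text.startswith(" ") else text


def is_word_start(tokenizer: ToyTokenizer, token_idx: int, token_0: int = 0) -> bool:
    # Prepend token "0" and skip the first character ("0") of the result;
    # a word-start token keeps its leading space this way, so its decoded
    # form comes out one character longer than when decoded alone.
    decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
    decoded_regular = tokenizer.decode([token_idx])
    return len(decoded_after_0) > len(decoded_regular)


tok = ToyTokenizer()
print(is_word_start(tok, 1))  # True:  "▁hello" decodes to " hello" after "0"
print(is_word_start(tok, 2))  # False: "lo" decodes identically either way
```

The same length comparison is what `_build_regular_tokens_list` performs for every non-special token in the real vocabulary.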