noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

Initialization is slow #75

Closed · turboderp closed this issue 5 months ago

turboderp commented 7 months ago

Hi. So I'm a bit confused by the contribution guidelines and don't want to submit an unsolicited PR.

But I was playing around with this, and the initialization seemed quite slow. Especially with Qwen models that have a 150k vocabulary, it takes over a minute to build the prefix trie (most of it spent in JsonFreetextTokenCache.freeze), but even with smaller vocabularies it would take upwards of 10 seconds.

So I made some optimizations here (https://github.com/noamgat/lm-format-enforcer/compare/main...turboderp:lm-format-enforcer:main). Would you like me to submit a PR?

Briefly, the changes are:

First, ExLlamaV2 integration reads the vocabulary straight from the ExLlamaV2Tokenizer instead of calling decode() on every token. This is somewhat faster, especially using the HF Tokenizer with Qwen, where the decode method is surprisingly expensive. This also fixes some bugs, I think?

token_0 = tokenizer.encode("0")[0]  # Returns multiple tokens if "0" encodes to more than one token
decoded_after_0 = tokenizer.decode(tensor_after_0)[1:]  # Seems to assume "0" encodes to one token
decoded_regular = tokenizer.decode(token_0)  # Always returns "0"
is_word_start_token = len(decoded_after_0) > len(decoded_regular)  # Considers all tokens that decode to more than one character to start a new word

So the output isn't exactly the same, but I assume it's more correct, since tokens will have is_word_start_token == True precisely when they are word start tokens.

The second change is to JsonFreetextTokenCache, which now constructs the cache using intersections on sets of ints and avoids having to convert back to token IDs at the end.
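
[Editor's note: a minimal sketch of that set-intersection idea. The names and the length-budget keying are assumptions for illustration, not the actual JsonFreetextTokenCache internals.]

json_safe_ids = set()        # hypothetical: IDs of tokens that are valid inside a JSON string
ids_within_budget = {}       # hypothetical: remaining-length budget -> set of IDs short enough to fit

def allowed_token_ids(budget):
    # One intersection on integer sets; no conversion back through token strings.
    return ids_within_budget.get(budget, set()) & json_safe_ids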

I've tested it on Mistral, Llama and Qwen and confirmed that the resulting cache is identical, except for:

  • The resulting tuples are sorted by token ID instead of by token text. I couldn't see anywhere this would matter, though. (?)
  • There are duplicate tokens in most models. For example in Mistral, token 37 is <0x22>, which is an ASCII double quote, while token 28739 is \". The way the cache was built before, only the last of these tokens was considered:

self.token_str_to_num[token_str] = token_int

And as a result, wherever " is a valid string, only token 28739 would be considered a valid token ID. So I think (?) it's more correct to allow both tokens in that case.
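
[Editor's note: a standalone toy illustration of the duplicate-token point, using the Mistral token IDs mentioned above; it is not the library's actual data structures.]

token_str_to_num = {}
token_str_to_num['"'] = 37      # Mistral token 37 (<0x22>) decodes to '"'
token_str_to_num['"'] = 28739   # Mistral token 28739 (\") also decodes to '"' and silently overwrites 37

token_str_to_ids = {}           # keeping a set of IDs per string preserves both candidates
token_str_to_ids.setdefault('"', set()).update({37, 28739})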

In any case, it does seem to still work in my tests, and it initializes 10-20x faster.

noamgat commented 7 months ago

There is a WIP branch for speeding this up, it's almost ready, and it addresses almost everything you raised here. There's one bug left to solve and then I'll merge.

noamgat commented 7 months ago

v0.9.0 was just released with the improvement. Can you update and check if the performance is now better?

turboderp commented 7 months ago

It is a lot better, yes, and somewhat faster than my approach. It still has the bugs in the ExLlamaV2 integration, and it's unusably slow for Qwen because it calls tokenizer.decode a lot, which for HF Tokenizer's Tiktoken implementation is amazingly inefficient for some reason.

I tested it on a few different models, and here is the time it takes to call ExLlamaV2TokenEnforcerFilter with a JSON schema parser:

model      0.8.3      0.9.0      mine       0.9.0 + exl2 change
Mistral    1.107 s    0.315 s    0.294 s    0.151 s
Llama2     1.114 s    0.313 s    0.297 s    0.146 s
Orion      3.034 s    0.802 s    0.821 s    0.373 s
Deepseek   64.471 s   57.910 s   0.693 s    0.150 s
Qwen       10+ min    10+ min    2.594 s    0.595 s

All seem to work correctly (except for the Qwen models that I lost patience with). The initialization of JsonFreetextTokenCache is faster in 0.9.0 than my set-based approach but I would still update the ExLlamaV2 integration.

It's a small and self-contained change:

def _build_regular_tokens_list(tokenizer: ExLlamaV2Tokenizer) -> List[Tuple[int, str, bool]]:
    vocab_size = tokenizer.tokenizer.vocab_size()
    # Exclude special tokens (extended pieces plus BOS/EOS/PAD/UNK)
    all_special_ids = set(tokenizer.extended_id_to_piece.keys())
    all_special_ids.update({ tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id })
    # Normalized id -> piece mapping; avoids calling decode() once per token
    id_to_piece = tokenizer.get_id_to_piece_list()
    regular_tokens = []
    for token_idx in range(vocab_size):
        if token_idx in all_special_ids:
            continue
        decoded = id_to_piece[token_idx]
        # A token starts a new word iff its piece begins with a leading space
        is_word_start_token = len(decoded) > 0 and decoded[0] == " "
        regular_tokens.append((token_idx, decoded, is_word_start_token))
    return regular_tokens

Last column above is 0.9.0 with this change applied.

turboderp commented 7 months ago

Any news on this?

noamgat commented 7 months ago

The reason for using decode is that it is the only way (as far as I know) to know which tokens are word-start tokens. In most tokenizers the leading space does not appear in this mapping, but we need it to build the correct prefix tree. Is there a solution for this?

turboderp commented 7 months ago

The data is usually available in the tokenizer model. For instance, with HF Tokenizers the id_to_token() function returns the raw string for each token, including the leading space, represented either by a space or by some placeholder character like Ġ or ▁ (which you can determine by encoding a single space character). Or you could simply have a list of common options and add to it if a new model shows up that uses a different character. It's unlikely that any solution would really be future-proof anyway.
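
[Editor's note: a rough sketch of that approach using the HF tokenizers library directly. The file path and the prefix-detection heuristic are assumptions for illustration, not lm-format-enforcer code.]

from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")  # assumed local path to an HF tokenizer file

# Find the word-start placeholder (e.g. 'Ġ' or '▁') by encoding a single space.
space_tokens = tok.encode(" ", add_special_tokens=False).tokens
space_prefix = space_tokens[0][0] if space_tokens else " "

def is_word_start(token_id: int) -> bool:
    piece = tok.id_to_token(token_id)  # raw piece, e.g. 'Ġroom' or '▁room'
    return piece is not None and piece.startswith(space_prefix)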

Anyway, the ExLlama tokenizer creates the id-to-piece mapping for this reason, with a bunch of logic to handle various formats, producing a list in a common format. Here are the contents of tokenizer.get_id_to_piece_list()[2000:2010] for a few models:

Mistral:

2000: '####'
2001: 'public'
2002: '[]'
2003: ' room'
2004: 'len'
2005: ' family'
2006: 'por'
2007: ' program'
2008: ' hist'
2009: ' mus'

Llama:

2000: ' called'
2001: 'Item'
2002: 'ura'
2003: 'vec'
2004: 'eme'
2005: ' della'
2006: 'embre'
2007: 'urg'
2008: 'Se'
2009: ' request'

Orion:

2000: '你就'
2001: 'way'
2002: '完成'
2003: '最近'
2004: 'dition'
2005: ' quick'
2006: '13'
2007: '介绍'
2008: '60'
2009: ' women'

Gemma:

2000: 'main'
2001: ' last'
2002: 'ida'
2003: ' water'
2004: ' must'
2005: ' But'
2006: 'ux'
2007: ' ear'
2008: ' cas'
2009: ' Ad'

Word start tokens are then simply the ones that begin with a space, and a sequence can be decoded as just the concatenation of the respective pieces for each token ID. The ExLlama tokenizer also needs this to build a trie (used for token healing and other stuff), which is the same data structure as the TokenizerPrefixTree in lm-format-enforcer, just with a slightly different format. In principle it could be translated, but that probably wouldn't be faster than constructing it again.
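
[Editor's note: a minimal sketch of the decode-by-concatenation idea, using the get_id_to_piece_list() method mentioned above; the helper name is just for this example.]

from typing import List
from exllamav2 import ExLlamaV2Tokenizer

def decode_ids(tokenizer: ExLlamaV2Tokenizer, ids: List[int]) -> str:
    # With a normalized id -> piece mapping, decoding is plain concatenation,
    # and word-start tokens are exactly the pieces that begin with a space.
    id_to_piece = tokenizer.get_id_to_piece_list()
    return "".join(id_to_piece[i] for i in ids)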

One place where this all breaks a bit is with multi-token characters, such as UTF-8 emitted by the model using byte fallback tokens. I've also had issues with how Chinese characters are encoded in Tiktoken (seems to be almost UTF-8 but not quite?). But none of that really pertains to the leading space, and it would be an issue in any case if you're assuming a token always represents at least one character (as I think you probably have to for this kind of character-based constrained sampling.)

noamgat commented 7 months ago

Thanks for the detailed response! I would accept a PR if it includes the following:

  1. Removal of the changes to tokenizerprefixtree.py that were made redundant by the other improvements.
  2. Adding a use_piece_id_mapping flag (could default to True) that decides which method to use.
  3. Adding a unit test that checks that, for popular tokenizers, the resulting list is the same.

The unit test could lazy-import exllamav2 and skip if it doesn't exist to avoid optional dependency and CI issues.
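
[Editor's note: a sketch of the lazy-import-and-skip pattern in pytest; illustrative only, the actual test body is omitted.]

import pytest

# Lazily import exllamav2 and skip this test module if it isn't installed,
# so the optional dependency doesn't break CI.
exllamav2 = pytest.importorskip("exllamav2")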

turboderp commented 7 months ago

Yes, I'm only suggesting the code snippet above. #76 is a bigger improvement as far as tokenizerprefixtree.py is concerned.

As for the integration, though, the original is unusably slow and I'm not sure how useful a unit test would be. I let it finish now just to be sure, and with the latest Tokenizers library (0.15.2) on the Windows PC I'm currently using and the Qwen1.5 tokenizer model, the _build_regular_tokens_list function takes 1 hour and 4 minutes to complete. With the change above it takes 37 milliseconds.

It also looks like the unit test would fail anyway because the current implementation isn't correct. Here's a snippet of the regular_tokens list, using Qwen:

idx   str     is_word_start_token
300   'as'    True
301   'el'    True
302   'ct'    True
303   'nd'    True
304   ' in'   True
305   ' h'    True
306   'ent'   True
307   'id'    True
308   ' n'    True
309   'am'    True

Vs. with the fix above:

idx   str     is_word_start_token
300   'as'    False
301   'el'    False
302   'ct'    False
303   'nd'    False
304   ' in'   True
305   ' h'    True
306   'ent'   False
307   'id'    False
308   ' n'    True
309   'am'    False

Here are some intermediates for the Llama SPM tokenizer, token idx 293:

        tensor_after_0 = torch.tensor(token_0.tolist() + [token_idx], dtype=torch.long)  # tensor([31822, 31852,   293])
        decoded_after_0 = tokenizer.decode(tensor_after_0)[1:]                           # 'ion'
        decoded_regular = tokenizer.decode(token_0)                                      # '0'
        is_word_start_token = len(decoded_after_0) > len(decoded_regular)                # True

I guess decoded_regular should have been tokenizer.decode([token_idx]) instead? That's still problematic, I think (?), since encoding and decoding the string '0' won't always yield the string '0', and encoding a single token may or may not prepend an extra space, depending on normalization options in the tokenizer.

All of this is dealt with in ExLlamaV2's tokenizer, specifically to produce a reliable token -> string mapping that strips out all of the normalization.

noamgat commented 6 months ago

Hi, I'd like to take a deeper look into this. Can you send a reproducing snippet (including which model correctly activates the Qwen1.5 tokenizer, etc.)?

turboderp commented 6 months ago

Sure, here's a snippet:

import sys, os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from pydantic import BaseModel
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer import JsonSchemaParser
import time
from typing import List, Tuple

# Initialize model, load only tokenizer

model_directory = "/mnt/str/models/smaug-72b-exl2/4.0bpw/"
config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

class TestSchema(BaseModel):
    value_a: str
    value_b: str

schema_parser = JsonSchemaParser(TestSchema.schema())

# Create filter

print("Init filter...")
time_a = time.time()
lmfe_filter1 = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
time_b = time.time() - time_a
print(f"Duration: {time_b:.4f}")

# Patch function and repeat

def _build_regular_tokens_list_new(tokenizer: ExLlamaV2Tokenizer) -> List[Tuple[int, str, bool]]:
    vocab_size = tokenizer.tokenizer.vocab_size()
    all_special_ids = set(tokenizer.extended_id_to_piece.keys())
    all_special_ids.update({ tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.unk_token_id })
    id_to_piece = tokenizer.get_id_to_piece_list()
    regular_tokens = []
    for token_idx in range(vocab_size):
        if token_idx in all_special_ids:
            continue
        decoded = id_to_piece[token_idx]
        is_word_start_token = len(decoded) > 0 and decoded[0] == " "
        regular_tokens.append((token_idx, decoded, is_word_start_token))
    return regular_tokens

import lmformatenforcer.integrations.exllamav2
lmformatenforcer.integrations.exllamav2._build_regular_tokens_list = _build_regular_tokens_list_new

print("Init filter (patched)...")
time_a = time.time()
lmfe_filter2 = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
time_b = time.time() - time_a
print(f"Duration: {time_b:.4f}")

# Compare regular token lists

max_diff = 20
for a, b in zip(lmfe_filter1.token_enforcer.regular_tokens, lmfe_filter2.token_enforcer.regular_tokens):
    if a == b: continue
    print(f"A: {repr(a):30}  !=  B: {repr(b):30}")
    max_diff -= 1
    if max_diff == 0: break

Note that the behavior might depend on your version of the tokenizers library. I tested this with the latest release, 0.15.2, and it indeed takes around an hour to run for Qwen. Not sure what's happening in that library but decoding is extremely slow for some reason.

The behavior is the same for all Qwen models, and you can reproduce it with, for instance, this one, though it also shows up to a lesser extent with other models that rely on the tokenizers library. I think it's an algorithmic choice since it seems to be exponential (maybe quadratic) with the size of the vocabulary. Here are results for Mistral with the SPM tokenizer:

Init filter...
Duration: 0.8121
Init filter (patched)...
Duration: 0.3278

And with HF tokenizer (reading tokenizer.json rather than tokenizer.model):

Init filter...
Duration: 59.6267
Init filter (patched)...
Duration: 0.3283

As for the output, it seems to be identical except that the existing version recognizes any multi-character token as a word start token, which seems wrong to me. Not sure what effect that has, since it still seems to constrain generation correctly regardless, at least for JSON.

There's also some example code here you could look at.

noamgat commented 5 months ago

Merged, released in v0.9.6. Thank you @turboderp and @bdashore3!