noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

Add support for string min/max length to json_freetext #43

Closed elonen closed 9 months ago

elonen commented 9 months ago

This PR allows the json_freetext optimization to be used with JSON Schema strings that have a max length, by keeping a separate cache for each token length that occurs in the LLM tokenizer.
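
A minimal sketch of that caching scheme. It assumes the freetext-eligible tokens are available as decoded strings and that a trailing closing quote adds no length to the string. The helper names (`content_len`, `build_length_caches`, `select_by_length`) are hypothetical and are not lm-format-enforcer's API. Like the first version of this PR, the lookup sorts the cache keys on every call:

```python
from typing import Dict, List

def content_len(token: str) -> int:
    """Characters a token adds to the string, not counting a closing quote."""
    return len(token) - 1 if token.endswith('"') else len(token)

def build_length_caches(freetext_tokens: List[str]) -> Dict[int, List[str]]:
    """For every token length L seen in the vocabulary, cache all tokens with content_len <= L."""
    caches: Dict[int, List[str]] = {}
    for limit in sorted({content_len(t) for t in freetext_tokens}):
        caches[limit] = [t for t in freetext_tokens if content_len(t) <= limit]
    return caches

def select_by_length(caches: Dict[int, List[str]], max_new_chars: int) -> List[str]:
    """Pick the largest cached list whose length limit fits within max_new_chars."""
    best: List[str] = []
    for limit in sorted(caches):  # per-call sort of the (few) length keys
        if limit > max_new_chars:
            break
        best = caches[limit]
    return best

vocab = ["a", "ab", "abc", "abcd", '"', 'x"']
caches = build_length_caches(vocab)
print(select_by_length(caches, 2))  # ['a', 'ab', '"', 'x"']
```

Passing a very large `max_new_chars` returns the biggest list, which is how a string without a max length keeps the old behavior.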

Fixes #42

elonen commented 9 months ago

Output example:

### Input:
Write a knock-knock joke in JSON format. The joke should have a preamble, a question, a name, a who, and a punchline, in that order.

### Response:

Answer, with plain str:
{
  "preamble": "Knock knock!",
  "question": "Who's there?",
  "name": "Banana",
  "who": "Banana who?",
  "punchline": "Banana-na-na-na-na-na-na-na-na-na-na!"
}
Answer, with constr (8 chars):
{
  "preamble": "Knock-kn",
  "question": "Who's",
  "name": "there",
  "who": "Orange",
  "punchline": "Orange,y"
}
Plain: 2.50 s, Constr: 1.81 s

noamgat commented 9 months ago

Interesting idea, should work. Some feedback:

elonen commented 9 months ago
  1. For min_length, I think the easiest way would be to duplicate the cached allowlists by introducing a json_non_terminating_freetext_tokens in addition to json_freetext_tokens. Those lists would contain all valid JSON substring tokens that don't end in a " (tokens that have one mid-token are already filtered out). The LLM should be allowed to pick any token, even the short ones, as long as it doesn't terminate the string before JsonSchemaParser determines that it's long enough.
  2. The performance hit should be minimal. When a string has no max length, it passes the maximum integer size as the token length limit, which makes freetext's dict selection loop always select the dict with the most tokens. The sorted() call could be optimized away by using an ordered dict instead (not that sorting 16 integers per token, 16 being the max token length I got, would show up anywhere next to everything else required to generate a token, even in Python).
  3. I'll have a look.
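
Item 1 above can be sketched as follows, assuming decoded token strings and `"` as the string terminator. `json_non_terminating_freetext_tokens` is the name proposed in the comment; the two functions and their parameters are hypothetical, and the per-length caches for max_length are left out for brevity:

```python
from typing import List, Optional, Tuple

def split_freetext_tokens(freetext_tokens: List[str]) -> Tuple[List[str], List[str]]:
    """Return (json_freetext_tokens, json_non_terminating_freetext_tokens)."""
    json_freetext_tokens = list(freetext_tokens)
    # Assumes tokens with a " before their last character were filtered out already,
    # so a token terminates the string exactly when it ends with "
    json_non_terminating_freetext_tokens = [t for t in freetext_tokens if not t.endswith('"')]
    return json_freetext_tokens, json_non_terminating_freetext_tokens

def pick_allowlist(parsed_len: int, min_length: Optional[int],
                   freetext: List[str], non_terminating: List[str]) -> List[str]:
    # Below the minimum, any token may be chosen as long as it does not close the string
    if min_length is not None and parsed_len < min_length:
        return non_terminating
    return freetext

ft, nt = split_freetext_tokens(["ab", 'c"', '"', "d"])
print(pick_allowlist(0, 3, ft, nt))  # ['ab', 'd']
```

This is conservative: while the string is still too short, a closing token that would itself reach min_length is excluded as well.
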
elonen commented 9 months ago

BTW, I think this check in freetext's initial token filtering currently causes suboptimal outputs:

has_quote_before_end = '"' in decoded[0:-1]

It should allow \" and, in case the string belongs to a list, ", " (which I think I've seen as a single token). I've seen a lot of outputs where the LLM adds unnecessary newlines to a list of strings, probably because it would have liked to pick ", " and the freetext mode prohibited it. In the next step, the LLM probably thinks it didn't pick that token in the previous step because it wanted to add more whitespace, which in practice means a newline, since nobody puts more than one space between items in a single-line JSON list. Depending on the tokenizer, a similar issue might arise with ": " in JSON dicts, but I haven't seen that one so far.

The \" case should be fairly easy to add, I guess, but I'm not sure what kind of black magic would be required to allow ", " in freetext mode. In addition to the string length limit(s) that JsonSchemaParser now passes to it in this PR, it would need to know that the current string is a member of a JSON list, but not the last item.
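
A possible shape for the \" case, written as a hypothetical replacement for the `has_quote_before_end` check: scan the token and skip whatever follows a backslash, so an escaped quote no longer disqualifies it. It assumes the token does not start in the middle of an escape sequence (a token that follows a lone backslash would need the parser's context), and it does not attempt to allow ", ":

```python
def has_unescaped_quote_before_end(decoded: str) -> bool:
    """True if `decoded` contains a string-terminating quote before its last character.

    A backslash escapes the next character, so an escaped quote or an escaped
    backslash is skipped as a unit.
    """
    last = len(decoded) - 1
    i = 0
    while i < len(decoded):
        ch = decoded[i]
        if ch == "\\":
            i += 2  # skip the backslash and the character it escapes
            continue
        if ch == '"' and i < last:
            return True  # an unescaped quote that is not the final character
        i += 1
    return False

print(has_unescaped_quote_before_end('a\\"b'))  # False: the quote is escaped
```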

EDIT: Apparently the Llama tokenizer doesn't have ", " but does have ",":

[Screenshot: 2023-12-21 at 13:09:06]

Not sure if I hallucinated it or what, but the principle holds, and I think this is the most likely explanation for the wonky outputs I've seen.

Anyway, that's a separate issue from this PR.

noamgat commented 9 months ago

That is allowed; see how tokenenforcer currently uses the json_freetext_tokens concept:

https://github.com/noamgat/lm-format-enforcer/blob/39ad002d6fcb2093dd4214db682d1d407b1134a6/lmformatenforcer/tokenenforcer.py#L130

It takes the JSON freetext token list, but it does recurse into the node of the tokenizer prefix tree that starts with ", because you need the context of the specific parsing point to know whether "," is allowed (e.g. after the value of the last key in an object, it is not).
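
A simplified sketch of that split, assuming a dict-based prefix tree over decoded token strings and a hypothetical `accepts(text)` callback standing in for the parser state at the current position (it returns True while `text` can still grow into a valid continuation). This is not tokenenforcer's actual code; it only illustrates why the subtree under `"` needs parser context while everything else can come from the cache:

```python
from typing import Any, Callable, Dict, List, Set

Trie = Dict[Any, Any]  # char -> child node; the key None holds a complete token

def build_trie(tokens: List[str]) -> Trie:
    root: Trie = {}
    for token in tokens:
        node = root
        for ch in token:
            node = node.setdefault(ch, {})
        node[None] = token
    return root

def allowed_after_freetext(cached_freetext: List[str], trie: Trie,
                           accepts: Callable[[str], bool]) -> Set[str]:
    allowed = set(cached_freetext)  # context-free part: straight from the cache

    def visit(node: Trie, text: str) -> None:
        if not accepts(text):  # prune: no token below this node can be valid here
            return
        if None in node:
            allowed.add(node[None])
        for ch, child in node.items():
            if ch is not None:
                visit(child, text + ch)

    if '"' in trie:
        visit(trie['"'], '"')  # context-dependent part: tokens that close the string
    return allowed

def in_list(text: str) -> bool:
    # Hypothetical parser state: a list item, which may be followed by ", " or "]"
    return any(c.startswith(text) for c in ('", ', '"]'))

vocab = ["a", "b", '"', '",', '", ', '"}', '"a']
print(sorted(allowed_after_freetext(["a", "b"], build_trie(vocab), in_list)))
# ['"', '",', '", ', 'a', 'b']
```

After the value of the last key in an object, the same walk with a callback that only accepts prefixes of `"}` would admit `"}` and reject `",`.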

elonen commented 9 months ago

Ok, the latest commit optimizes away the inference-time sorted() call by using an OrderedDict.
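
A sketch of a lookup that needs no per-token sorted() call, assuming the per-length caches were inserted into an OrderedDict in ascending key order when they were built; the function name is hypothetical:

```python
from collections import OrderedDict
from typing import List

def select_allowlist(caches: "OrderedDict[int, List[int]]", max_len: int) -> List[int]:
    """Return the cached token-ID list for the largest length key <= max_len.

    Relies on `caches` having been filled in ascending key order at build time,
    so iteration order is already sorted.
    """
    selected: List[int] = []
    for length, token_ids in caches.items():
        if length > max_len:
            break  # every later key is larger, so stop early
        selected = token_ids
    return selected

caches = OrderedDict([(1, [10]), (2, [10, 11]), (4, [10, 11, 12])])
print(select_allowlist(caches, 3))  # [10, 11]
```

On Python 3.7 and later a plain dict also preserves insertion order, so the OrderedDict mainly documents the ordering requirement.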

I also changed the current_parser.min_length is None test into (current_parser.min_length is None or len(current_parser.parsed_string) >= current_parser.min_length). This already optimizes cases where min_length is small (e.g. 1-2 characters), because after the minimum is reached, it can use freetext until max_length is reached (if ever).
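
The revised condition as a small predicate. `ParserState` is a hypothetical stand-in that exposes only the `min_length` and `parsed_string` attributes named above:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ParserState:
    min_length: Optional[int]
    parsed_string: str

def can_use_freetext(parser: ParserState) -> bool:
    # Previously only `min_length is None` qualified; now freetext is also used
    # once the string already satisfies its minimum length.
    return parser.min_length is None or len(parser.parsed_string) >= parser.min_length

print(can_use_freetext(ParserState(min_length=2, parsed_string="a")))  # False
```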

Optimizing min_length further

Contrary to what I wrote earlier, to optimize for large min_length I think you'd actually need a 2D dictionary of allowlists, that is, O(N^2) of them, where N is the max token length (around 16 for current tokenizers). Each would contain the tokens that are "at least X chars (excluding the trailing "), and at most Y chars". For a max token length of 16, that's 256 cached lists (or ~128 if space-optimized into triangular form, since Y >= X always).

That's at least 16-32MB of cached integer lists in RAM (depending on how Python optimizes them), which is not nothing, but probably fine since the models themselves tend to be way bigger. The logic is not trivial though, so I didn't attempt to implement the scheme for now.
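
A back-of-envelope check of these figures. The grid counts follow directly from N. The memory estimate is a rough assumption: 8 bytes per list slot (one pointer on 64-bit CPython, not counting the int objects, which lists can share) and 16k to 32k token IDs per list (roughly half to all of a 32k vocabulary), applied to the ~128 lists mentioned above:

```python
from typing import Tuple

def allowlist_counts(n: int) -> Tuple[int, int]:
    """Return (full n x n grid, triangle with Y >= X) allowlist counts."""
    return n * n, n * (n + 1) // 2

def cache_megabytes(num_lists: int, ids_per_list: int, bytes_per_slot: int = 8) -> float:
    # Counts only the list slots (one pointer each), in decimal megabytes
    return num_lists * ids_per_list * bytes_per_slot / 1_000_000

print(allowlist_counts(16))  # (256, 136)
print(allowlist_counts(17))  # (289, 153): lengths 0..16 inclusive
print(cache_megabytes(128, 16_000), cache_megabytes(128, 32_000))  # 16.384 32.768
```

Counting lengths 0 through 16 inclusive gives 17 values and 153 lists, which matches the number of allowlists reported later in the thread.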

Alternatively, if it's acceptable to use the 30x slower generation method for at most max-token-length characters of each min_length-constrained JSON string, you could get away with just one json_non_terminating_freetext_tokens list. It would contain all tokens (of any length) that don't terminate the string, and would be sampled from until fewer than max-token-length characters remain before min_length is reached. The algorithm would then temporarily switch to the slower method, and switch back to freetext once min_length is exceeded.
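
The half-measure could be driven by a mode switch like the one below. This is a sketch under stated assumptions: the mode labels are hypothetical, `max_token_len` is the longest token in characters including a closing quote, and "slow" stands for the regular (non-freetext) token-by-token check:

```python
from typing import Optional

def choose_mode(cur_len: int, min_length: Optional[int], max_token_len: int) -> str:
    if min_length is None or cur_len >= min_length:
        return "freetext"  # minimum satisfied: the regular cached freetext lists apply
    if min_length - cur_len >= max_token_len:
        # Far from the minimum: a closing token carries at most max_token_len - 1
        # characters, so none can reach min_length and the non-terminating list is exact.
        return "non_terminating_freetext"
    return "slow"  # near the minimum: some closing tokens may be valid, check precisely

print(choose_mode(cur_len=90, min_length=100, max_token_len=16))  # slow
```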

elonen commented 9 months ago

Sorry, I got my brain in a knot thinking about the further min_length optimization; I have to think it through and test some code before coming back to it.

The max_length optimization should work fine, though, as should the half-measure for min_length, so if you want to merge this as-is for now, I can make another PR once I'm confident in a more general min_length solution. Otherwise, I'll probably add another commit to this PR if it's still open at that point.

elonen commented 9 months ago

Progress! I've designed an algorithm that creates an allow list for all valid (min_remaining, max_allowed_len) pairs, where:

min_remaining = min(max_token_len, max(0, min_length - cur_len))  # no EOS (closing quote) allowed before this many more chars
max_allowed_len = min(max_token_len, max_length - cur_len)  # max new (non-EOS) characters allowed

allowed_tokens = cached_lists[(min_remaining, max_allowed_len)]  # simple dict lookup
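
The key computation wrapped in a helper. The function name is hypothetical, and treating a missing min_length as 0 and a missing max_length as unbounded is an assumption of this sketch:

```python
from typing import Optional, Tuple

def allowlist_key(cur_len: int, min_length: Optional[int], max_length: Optional[int],
                  max_token_len: int) -> Tuple[int, int]:
    """Return the (min_remaining, max_allowed_len) key used to look up a cached allowlist."""
    min_len = 0 if min_length is None else min_length
    budget = max_token_len if max_length is None else max_length - cur_len
    min_remaining = min(max_token_len, max(0, min_len - cur_len))  # closing quote not allowed before this
    max_allowed_len = min(max_token_len, budget)  # max new characters in the next token
    return min_remaining, max_allowed_len

print(allowlist_key(cur_len=3, min_length=5, max_length=8, max_token_len=16))  # (2, 5)
```

Clamping both values to `max_token_len` is what keeps the key space finite: with a max token length of 16, every lookup lands in a 17 x 17 grid, of which only the pairs with max_allowed_len >= min_remaining are stored.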

Below are my test results with the Llama tokenizer.

I'll adapt this for lm-format-enforcer next.

Number of allowlists: 153
Number of unique allowlists: 75

### Number of tokens in each list:
max_len         0     1     2     3      4      5      6      7      8      9      10     11     12     13     14     15     16
min_remaining
0                1  3286  5262  9473  14750  19329  22934  25786  27897  29342  30271  30836  31142  31312  31384  31413  31431
1              NaN  3285  5261  9472  14749  19328  22933  25785  27896  29341  30270  30835  31141  31311  31383  31412  31430
2              NaN   NaN  5241  9452  14729  19308  22913  25765  27876  29321  30250  30815  31121  31291  31363  31392  31410
3              NaN   NaN   NaN  9440  14717  19296  22901  25753  27864  29309  30238  30803  31109  31279  31351  31380  31398
4              NaN   NaN   NaN   NaN  14714  19293  22898  25750  27861  29306  30235  30800  31106  31276  31348  31377  31395
5              NaN   NaN   NaN   NaN    NaN  19293  22898  25750  27861  29306  30235  30800  31106  31276  31348  31377  31395
6              NaN   NaN   NaN   NaN    NaN    NaN  22898  25750  27861  29306  30235  30800  31106  31276  31348  31377  31395
7              NaN   NaN   NaN   NaN    NaN    NaN    NaN  25750  27861  29306  30235  30800  31106  31276  31348  31377  31395
8              NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN  27861  29306  30235  30800  31106  31276  31348  31377  31395
9              NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN  29306  30235  30800  31106  31276  31348  31377  31395
10             NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN    NaN  30235  30800  31106  31276  31348  31377  31395
11             NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN  30800  31106  31276  31348  31377  31395
12             NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN  31106  31276  31348  31377  31395
13             NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN  31276  31348  31377  31395
14             NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN  31348  31377  31395
15             NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN  31377  31395
16             NaN   NaN   NaN   NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN    NaN  31395

### Number of tokens with trailing EOS:
max_len         0    1    2    3    4    5    6    7    8    9    10   11   12   13   14   15  16
min_remaining
0                1   21   33   36   36   36   36   36   36   36   36   36   36   36   36   36  36
1              NaN   20   32   35   35   35   35   35   35   35   35   35   35   35   35   35  35
2              NaN  NaN   12   15   15   15   15   15   15   15   15   15   15   15   15   15  15
3              NaN  NaN  NaN    3    3    3    3    3    3    3    3    3    3    3    3    3   3
4              NaN  NaN  NaN  NaN    0    0    0    0    0    0    0    0    0    0    0    0   0
5              NaN  NaN  NaN  NaN  NaN    0    0    0    0    0    0    0    0    0    0    0   0
6              NaN  NaN  NaN  NaN  NaN  NaN    0    0    0    0    0    0    0    0    0    0   0
7              NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0    0    0    0    0    0    0    0   0
8              NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0    0    0    0    0    0    0   0
9              NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0    0    0    0    0    0   0
10             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0    0    0    0    0   0
11             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0    0    0    0   0
12             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0    0    0   0
13             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0    0   0
14             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0    0   0
15             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    0   0
16             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   0

### ID of each list (to illustrate deduplication):
max_len         0    1    2    3    4    5    6    7    8    9    10   11   12   13   14   15  16
min_remaining
0                0    6    8    7    9   10   11   12   13   42   43   44   45   46   47   48  49
1              NaN    4    5   50   51   52   65   14   53   66   28   54   55   56    3   57  15
2              NaN  NaN   67   29   30   62   68   69   31   32   16   33   34   35   36   37  58
3              NaN  NaN  NaN   38   21   70   17   59   71   18   19   22   72   73   74   60   1
4              NaN  NaN  NaN  NaN   40   24   41   61   64    2   25   63   20   23   26   27  39
5              NaN  NaN  NaN  NaN  NaN   24   41   61   64    2   25   63   20   23   26   27  39
6              NaN  NaN  NaN  NaN  NaN  NaN   41   61   64    2   25   63   20   23   26   27  39
7              NaN  NaN  NaN  NaN  NaN  NaN  NaN   61   64    2   25   63   20   23   26   27  39
8              NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   64    2   25   63   20   23   26   27  39
9              NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN    2   25   63   20   23   26   27  39
10             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   25   63   20   23   26   27  39
11             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   63   20   23   26   27  39
12             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   20   23   26   27  39
13             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   23   26   27  39
14             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   26   27  39
15             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN   27  39
16             NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  39

All 10000 tests passed!

The algorithm is actually pretty simple:

from typing import Dict, List, Tuple

EOS = '"'  # "end of string": the closing quote of a JSON string (assumed meaning)

def precalc_allowlists(all_tokens: List[str]) -> Dict[Tuple[int, int], Tuple[str, ...]]:

    def valid_for_min_remaining(token: str, min_remaining: int) -> bool:
        # A token may close the string only if its content reaches min_remaining
        return not token.endswith(EOS) or len(token.rstrip(EOS)) >= min_remaining

    def valid_for_max_len(token: str, max_len: int) -> bool:
        # The closing quote itself does not count toward the length
        return len(token.rstrip(EOS)) <= max_len

    # Make a 2D array of constrained allowlists, indexed by tuple `(min_remaining, max_len)`
    token_lists: Dict[Tuple[int, int], Tuple[str, ...]] = {}
    max_token_len = max(len(token) for token in all_tokens)
    for min_remaining in range(max_token_len + 1):
        for max_len in range(max_token_len + 1):
            if max_len >= min_remaining:  # Skip combinations that are never used
                token_lists[(min_remaining, max_len)] = tuple(sorted(
                    token for token in all_tokens
                    if valid_for_min_remaining(token, min_remaining) and valid_for_max_len(token, max_len)
                ))

    # Deduplicate the lists to save RAM, as many of them will be identical:
    # every equal tuple is replaced by one shared instance (values only, keys unchanged)
    canonical: Dict[Tuple[str, ...], Tuple[str, ...]] = {}
    for key, lst in token_lists.items():
        token_lists[key] = canonical.setdefault(lst, lst)

    return token_lists

elonen commented 9 months ago

Ok, I've committed a revised version which fully optimizes min_length, too.

elonen commented 9 months ago

Added some unit tests, squashed and rebased on current HEAD. Ready for your new review.

noamgat commented 9 months ago

Thank you for your contribution! I will review it in a couple of days.

noamgat commented 9 months ago

Looks great! Added a few comments, after they are resolved I will merge.

elonen commented 9 months ago

> Looks great! Added a few comments, after they are resolved I will merge.

Where did you add the notes? I couldn't find them.

elonen commented 9 months ago

Fixes done. I also moved the json.loads test inside JsonFreetextTokenCache, which took care of the skipped new_word_tokens addition.
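
A json.loads-based check of this kind might look like the sketch below: wrap the token in quotes and see whether the result parses as a JSON string, which rejects raw control characters, broken escapes and stray quotes. It is illustrative only, not JsonFreetextTokenCache's code. Accepting a token whose final quote closes the string is an assumption of the sketch, as is the requirement that the token not start right after a lone backslash:

```python
import json

def _parses_as_string(content: str) -> bool:
    try:
        json.loads('"' + content + '"')  # strict mode rejects raw control characters
        return True
    except json.JSONDecodeError:
        return False

def fits_in_json_string(token: str) -> bool:
    """True if the token can continue a JSON string literal, either purely as
    content or as content followed by the closing quote."""
    if _parses_as_string(token):
        return True
    return token.endswith('"') and _parses_as_string(token[:-1])

print(fits_in_json_string('ab"'), fits_in_json_string("a\nb"))  # True False
```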

noamgat commented 9 months ago

Merged! Thanks for the contribution! I will do some E2E testing with the sample notebooks in the coming days, and if no surprises pop up, I'll release a new version.