dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

Accelerate the index construction process #795

Open aeft opened 5 months ago

aeft commented 5 months ago

What behavior of the library made you think about the improvement?

My understanding of the index construction process is that for each state in the FSM, we need to iterate through all tokens in the vocabulary (e.g., 32,000 of them) and find all valid tokens (i.e., those that lead to valid states).
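To make the baseline concrete, here is a minimal sketch of the process as described above, assuming a toy FSM represented as a dict mapping `(state, char)` pairs to next states. The names and representation are illustrative only, not outlines' actual internals (which use interegular and numba):

```python
def walk(fsm_transitions, state, token):
    """Return the state reached after consuming `token`, or None on a dead end."""
    for ch in token:
        state = fsm_transitions.get((state, ch))
        if state is None:
            return None
    return state

def build_index(fsm_transitions, states, vocabulary):
    """Naive index construction: a full vocabulary scan per FSM state."""
    index = {}
    for state in states:                                # outer loop over FSM states
        for tok_id, token in enumerate(vocabulary):     # full vocabulary scan
            next_state = walk(fsm_transitions, state, token)
            if next_state is not None:
                index.setdefault(state, {})[tok_id] = next_state
    return index
```

The cost is states x vocabulary token walks, which is what the proposals below try to reduce.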

How would you like it to behave?

First, let me give an example:

Consider the current state and its outgoing edges. If the outgoing edges are limited (for example, a single Number edge), then a state transition happens only if the first character of the next token is a number. We can exploit this to avoid the huge vocabulary iteration for this state: we only iterate through tokens whose first character is a number.

We can apply this idea to other characters as well. To do this, we need to preprocess the vocabulary, i.e., find the corresponding tokens for each possible first character. There are two ways:

  1. We sort the vocabulary by first character. Then, for each character, we maintain an index pair (start_index, end_index) indicating the range of tokens that start with it.
  2. We don't change the original vocabulary but create a new sorted vocabulary.

The first way saves a little memory, but whether it is viable depends on whether the order of the original vocabulary matters.

The cost introduced by this method (let's call it method 1) is the sorting cost. If the number of states is large, or if we can cache the sorted vocabulary, this cost is amortized and becomes negligible.
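A minimal sketch of method 1, with hypothetical names (this is not outlines' API): sort the vocabulary once, record a (start, end) slice per first character, and then a state whose outgoing edges only accept digits scans roughly ten small ranges instead of the full vocabulary.

```python
from bisect import bisect_left, bisect_right

def build_first_char_ranges(vocabulary):
    """Sort token ids lexicographically and map each first character to its range."""
    order = sorted(range(len(vocabulary)), key=lambda i: vocabulary[i])
    firsts = [vocabulary[i][0] for i in order]   # sorted list of first characters
    ranges = {}
    for ch in set(firsts):
        ranges[ch] = (bisect_left(firsts, ch), bisect_right(firsts, ch))
    return order, ranges

def candidates(order, ranges, allowed_first_chars):
    """Yield token ids whose first character is allowed by the state's out edges."""
    for ch in allowed_first_chars:
        start, end = ranges.get(ch, (0, 0))
        for pos in range(start, end):
            yield order[pos]
```

This follows option 2 above (a separate sorted view via `order`), so the original vocabulary is left untouched.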


Can we extend this idea further?

We can build a trie over the vocabulary, then traverse it with BFS while maintaining the corresponding FSM state at each node. For tokens that share a prefix, we then don't need to walk the FSM multiple times. For example, for the tokens "hel", "hell", and "hello", the "h->e->l" transition path only needs to be traversed once.
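A minimal sketch of method 2 under the same toy assumptions (an FSM as a dict of `(state, char) -> state` transitions; names are illustrative): a nested-dict trie traversed breadth-first, pruning any branch whose next character has no transition, so shared prefixes walk the FSM exactly once.

```python
from collections import deque

def build_trie(vocabulary):
    """Build a nested-dict trie; a None key marks the end of a token."""
    root = {}
    for tok_id, token in enumerate(vocabulary):
        node = root
        for ch in token:
            node = node.setdefault(ch, {})
        node[None] = tok_id
    return root

def valid_tokens(fsm_transitions, start_state, trie):
    """Map each reachable token id to the FSM state it lands on."""
    result = {}
    queue = deque([(trie, start_state)])
    while queue:
        node, state = queue.popleft()
        for ch, child in node.items():
            if ch is None:
                result[child] = state            # child is the token id here
            else:
                nxt = fsm_transitions.get((state, ch))
                if nxt is not None:              # prune dead branches early
                    queue.append((child, nxt))
    return result
```

With the tokens "hel", "hell", and "hello", the "h", "e", "l" transitions are looked up once each rather than once per token, and whole subtrees are skipped as soon as a character has no outgoing edge.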

The real performance gain from method 2 depends on the implementation, because the original implementation uses numba.jit, which may already compensate for the cost of the naive algorithm.


Note: with both method 1 and method 2, we still need to consider all states. Are there methods to skip some states? I'm still figuring that out. (Maybe it's impossible, because in the worst case the LLM generates one character at a time.)

I might have gotten something wrong, and I look forward to your advice if you are available. Thank you! If you don't mind, I would like to implement this once the method is mature.

rlouf commented 5 months ago

It's hard to evaluate the performance improvement ahead of time; at this point there is no other choice but to implement and benchmark.

As far as I understand, this is close to what was proposed in https://github.com/outlines-dev/outlines/pull/507

aeft commented 5 months ago

It seems the trie implementation in #507 never got merged. Was there a reason?

rlouf commented 5 months ago

Said PR was making several changes at the same time, with different impacts on the compile time, which made it difficult to evaluate the performance impact. We asked for the PR to be split, but the author never got around to it. Happy to review a PR that implements a single change and with proper and extensive benchmarks.

We have to be conservative with that part of the code and make sure we don't introduce performance regressions.