@EricLBuehler do you have further steps planned for this?
Looking at similar designs:
- Outlines generates an FSM from a regex (using https://github.com/MegaIng/interegular) and transforms it from text-based to token-based (https://github.com/outlines-dev/outlines/blob/main/outlines/fsm/regex.py#L47), which is later used to influence the logits by vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/guided_logits_processors.py#L28
- SGLang (https://github.com/sgl-project/sglang) also uses the Outlines FSM, but additionally implements a jump-forward optimization: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/constrained/jump_forward.py#L8 https://lmsys.org/blog/2024-02-05-compressed-fsm/
- Llama.cpp seems to compute the FSM state as it walks it, performing the same logit operations (https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp#L12510) on a data structure (https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp#L11954) built after parsing the grammar.
- The vLLM roadmap also shows they plan to integrate lm-format-enforcer, which does not pre-build the FSM: https://github.com/vllm-project/vllm/issues/3713
- There is also AICI: https://github.com/vllm-project/vllm/pull/2888
AICI is a Rust library; I wonder if it wouldn't be simpler for mistral.rs to implement it, or to implement it alongside the other approaches.
@lucasavila00, I plan on working on this next week. Would you be able to put together a skeleton PR? This looks like an exciting development.
Sure, I'll try to integrate AICI over the weekend.
Implementing AICI has proven too complex for me to do quickly.
What I learned on https://github.com/EricLBuehler/mistral.rs/pull/81 :
To run the server on that PR, use
./target/release/mistralrs-server --aicirt=PATH_TO_AICIRT --port 1234 --log target/output.txt mistral-gguf
where PATH_TO_AICIRT is the binary from the GitHub releases: https://github.com/microsoft/aici/releases/tag/v0.0.10
Use `mistral-gguf` because I hardcoded some settings instead of building the code that reads them from the config.json and tokenizer.json. These places were annotated with TODOs.
It only partially implemented the `mid` part of the protocol; `pre` and `post` have not been implemented.
Thank you for working on this! We will probably begin work on this early next week.
Like you said, AICI is pretty complicated, and we will probably not use it. I will look into the SGLang method you mentioned. If you have any ideas for implementing grammars, please let me know!
I have the https://github.com/lucasavila00/LmScript project, which supports both vLLM and SGLang, so those are really the only two backends I know.
Both use Outlines FSM.
The RadixAttention used by SGLang makes it faster to use prompts that generate many small pieces. For example, an ad-hoc structured XML generation https://github.com/lucasavila00/LmScript/blob/main/packages/client/src/backends/executor.ts#L78-L136
Also, SGLang has two different select-from-choices operations.
If one uses a regex with multiple options, the backend eagerly selects the most probable choice token by token. This can be unintuitive (e.g. https://github.com/guidance-ai/guidance/issues/564).
If one uses the native SGLang `select` operation, it calculates the probability of each option as a whole: https://github.com/sgl-project/sglang/blob/ff99c38a0711ee82926840129db840a70e91f0d9/python/sglang/backend/runtime_endpoint.py#L191-L242
SGLang's `select` is excellent in terms of result quality, and it is implemented with just a few extra settings sent to the server that tell it which tokens to return the logprobs of. Of course, this only works efficiently because the RadixAttention cache can reuse the computation of the common prefix.
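Conceptually, `select` does something like the following rough sketch (not the actual SGLang code; `logprob_of` is a placeholder for a logprob query against the server, and length normalization is optional):

```rust
/// Sketch: score each option by the (length-normalized) sum of the logprobs of
/// its tokens given the shared prefix, then pick the best option as a whole.
fn select(
    prefix: &[u32],
    options: &[Vec<u32>],
    logprob_of: impl Fn(&[u32], u32) -> f32, // logprob of a token given a context
) -> usize {
    let mut best = (0usize, f32::NEG_INFINITY);
    for (i, option) in options.iter().enumerate() {
        let mut ctx = prefix.to_vec();
        let mut score = 0.0f32;
        for &tok in option {
            score += logprob_of(&ctx, tok);
            ctx.push(tok);
        }
        let score = score / option.len() as f32; // normalize by option length
        if score > best.1 {
            best = (i, score);
        }
    }
    best.0
}
```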
I like SGLang a lot. The only issue is that it takes a long time to pre-compile the regexes, which are then saved to disk and available for re-use. I agree with https://github.com/vllm-project/vllm/issues/3713 "For endpoint products running model as a service with customers supplying many different schemas, the cost might not be acceptable."
I would like it if SGLang's design could either compile regexes quickly or reject them for excessive complexity; that is the major flaw I see with it, along with poor error messages when it receives invalid regexes and so on.
I did look into Rust's regex-automata for the regex compilation (https://github.com/rust-lang/regex/blob/master/regex-automata/src/dfa/mod.rs#L265-L270), but as the linked lines say, it is hard to make this compilation efficient.
Ah, SGLang doesn't do token healing. Only Guidance does https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38
It would be amazing if SGLang's design also had token healing.
It does require a lot of changes to support, though: one needs to remove part of the initial prompt and then let the model re-generate the removed part.
If this is a desired feature, you might want to keep it in mind while building the other parts to make an eventual implementation easier.
Other interesting things I found:
- The latest version of TGI (https://github.com/huggingface/text-generation-inference) also uses the Outlines FSM and supports regex decoding.
- Llama.cpp has performance issues in its grammar implementation: https://github.com/ggerganov/llama.cpp/issues/4218#issuecomment-1836540046
- Kalosm implemented regex-based decoding using the regex-automata DFA, as I mentioned above: https://github.com/huggingface/candle/issues/1945#issuecomment-2027242005
- vLLM has prefix caching now: https://github.com/vllm-project/vllm/issues/2614 (Unfortunately I couldn't test it, as my RTX 2070 is too old for it, so I can't tell whether it works as well as SGLang's RadixAttention. I'm in the process of replacing the GPU and should be able to compare the approaches in a week or so.)
Thanks for the links. I really like the idea of using BNF; maybe it could be converted to a regex. After looking at the Kalosm implementation and this issue, I think we could implement this with the logit bias and just parse the BNF. Reading the llama.cpp issue, I think there are a few considerations to take into account when implementing:
1) Use a parsing algorithm with polynomially bounded time.
2) Run normal sampling first, and only spend the time computing the grammar logit bias if it returns a token that does not match the grammar.
Potentially, we could use a regex DFA to handle the matching. What do you think about that plan?
Regarding prompt caching, I think that is workable, as we could just implement an eviction policy for KV caches. However, that is a separate topic, and perhaps you could raise a tracking issue for it?
> I really like the idea of using BNF
Outlines has an example of `.lark` files. Terminals are regexes, and one can build the parser on top of the terminals: https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/common.lark https://github.com/outlines-dev/outlines/blob/main/outlines/fsm/guide.py
> Potentially, we could use a regex DFA to handle the matching. What do you think about that plan?
Generating the FSM for a "big" regex is slow. This SGLang example of a JSON object with ~10 fields takes a minute to compile: https://github.com/sgl-project/sglang?tab=readme-ov-file#json-decoding
A mixture of an "interpreter" for the BNF/grammar, where each terminal is a compiled regex, could work. We would cache the regex FSMs and build big JSON objects without a big compilation delay. This is how Outlines implements the `.lark` files.
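A minimal sketch of that idea, assuming regex-automata's dense DFA (the names here are illustrative, not Outlines internals): compile each terminal's regex once and reuse it across rules.

```rust
use std::collections::HashMap;
use regex_automata::dfa::dense::DFA;

/// Sketch: cache compiled DFAs per terminal pattern so a big grammar/schema
/// only pays the regex compilation cost once per distinct terminal.
struct TerminalCache {
    dfas: HashMap<String, DFA<Vec<u32>>>,
}

impl TerminalCache {
    fn new() -> Self {
        Self { dfas: HashMap::new() }
    }

    fn get_or_compile(&mut self, pattern: &str) -> &DFA<Vec<u32>> {
        self.dfas
            .entry(pattern.to_string())
            .or_insert_with(|| DFA::new(pattern).expect("terminal regex should compile"))
    }
}
```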
> Regarding prompt caching, I think that is workable, as we could just implement an eviction policy for KV caches. However, that is a separate topic, and perhaps you could raise a tracking issue for it?
Sure.
True, compiling a big BNF would be very costly, so I like the idea of only compiling the terminals. I wonder if we can multithread the logit bias calculation with a thread pool; we could do something like:
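Roughly this kind of sketch, assuming a rayon thread pool (`grammar_allows` stands in for whatever recognizer we end up using):

```rust
use rayon::prelude::*;

/// Sketch: split the vocab into chunks and compute each chunk's logit bias in
/// parallel on the rayon thread pool. `grammar_allows` is a placeholder for the
/// grammar/regex check of a single token.
fn compute_bias_chunks(
    vocab_size: usize,
    chunk_size: usize,
    grammar_allows: impl Fn(u32) -> bool + Sync,
) -> Vec<Vec<f32>> {
    let token_ids: Vec<u32> = (0..vocab_size as u32).collect();
    token_ids
        .par_chunks(chunk_size)
        .map(|chunk| {
            chunk
                .iter()
                .map(|&tok| if grammar_allows(tok) { 0.0 } else { f32::NEG_INFINITY })
                .collect()
        })
        .collect()
}
```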
And then accumulate them:
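For example (same caveats as the sketch above):

```rust
/// Flatten the per-chunk biases back into a single vocab-length bias vector.
fn accumulate_bias(chunks: Vec<Vec<f32>>) -> Vec<f32> {
    chunks.into_iter().flatten().collect()
}
```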
The vocab size for a Mistral model is 32000, so iterating over all of it would be expensive in the sampling hot loop! We are currently ~3 ms/T slower than llama.cpp on an A10 + Mistral Q4_K_M GGUF, so I want to improve performance.
On performance, I think the explanation of the AICI protocol is good: https://github.com/microsoft/aici/blob/main/docs/aicirt-proto.md
> ...
> - the LLM schedules some of the non-suspended sequences to compute logits for
> - the LLM informs AICIrt about the scheduled sequences; ...
> - the LLM starts computing logits
> - AICIrt sends the logit biases to the LLM
> - the LLM adds computed logits and biases and samples tokens
> ...
The bias computation should be done on a different thread, concurrently with the logit calculation, and it should not start at sampling time but as soon as the GPU starts the logit step.
According to AICI, if done this way, there is little cost: https://github.com/microsoft/aici?tab=readme-ov-file#performance
> For example, computing the allowed token set in the 32000-strong vocabulary of the Llama model takes:
> - about 2.0ms for the Yacc grammar of the C programming language
> - about 0.3ms for a regular expression
> - about 0.2ms for a substring constraint, from a 4kB string
Since the logits calculation usually takes longer than that, and servers usually have many CPUs, there is effectively no cost.
Ah, that is great: that way the logit biases are ready before the logits are! Even so, I think it would still be best, as an optimization, to sample first without applying the logit biases, since applying them must be sequential and may iterate over most of the vocab. In fact, while the initial sampling is occurring, a worker thread can apply the logit bias.
If you want to open a PR laying the groundwork for some of this, such as the logit bias application and dual-sampling setup, please feel free! I am hoping to implement grammar support this week.
Awesome!
I'll have a busy week. I fear I'll only be available next weekend.
Once I have the time I'll look around for what I can help with.
I was reading a bit more about DFAs, FSMs and so on while thinking about how to implement this, and I stumbled upon the details of the AICI ABI:
https://github.com/microsoft/aici/tree/main/controllers/aici_abi
It works pretty much like Kalosm in that it runs the regex for every token. However, it does so over a token trie, so tokens that share a common prefix don't require re-computation (see the sketch after the links below).
They provide:
- grammars with regex terminals: https://github.com/microsoft/aici/tree/main/controllers/aici_abi#lr1-grammars
- regular expressions: https://github.com/microsoft/aici/tree/main/controllers/aici_abi#regular-expressions
- a low-level API of the kind the sampler needs, which can be used both for the first optimistic check and for the re-sampling: https://github.com/microsoft/aici/tree/main/controllers/aici_abi#low-level-interface
It requires an instance of TokTrie, which can be built as I did in the previous AICI PR: https://github.com/EricLBuehler/mistral.rs/pull/81/files#diff-8e7ab085145c61f5613962a6b30db55a222fa6bf6e432e82df9ab90dbfb4627aR900-R903
And a stateful Recognizer instance per sequence, like this regex recognizer (https://github.com/microsoft/aici/blob/main/controllers/aici_abi/src/rx.rs#L20) or this grammar recognizer (https://github.com/microsoft/aici/blob/main/controllers/aici_abi/src/cfg.rs).
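To illustrate the prefix-sharing idea mentioned above, here is a simplified sketch (not the actual aici_abi types): the recognizer state is cloned per trie edge, so tokens that share a prefix reuse the work already done for that prefix.

```rust
/// Minimal recognizer interface for this sketch; `advance` consumes one byte
/// and reports whether the constraint can still match.
trait Recognizer: Clone {
    fn advance(&mut self, byte: u8) -> bool;
}

/// One node of a byte-level token trie; a token id is stored on the node where
/// that token's byte sequence ends.
struct TrieNode {
    byte: u8,
    token_id: Option<u32>,
    children: Vec<TrieNode>,
}

/// Walk the trie, cloning the recognizer state per edge instead of re-running
/// it from scratch for every token, and collect the allowed token ids.
fn allowed_tokens<R: Recognizer>(node: &TrieNode, state: &R, out: &mut Vec<u32>) {
    for child in &node.children {
        let mut next = state.clone();
        if next.advance(child.byte) {
            if let Some(id) = child.token_id {
                out.push(id);
            }
            allowed_tokens(child, &next, out);
        }
    }
}
```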
@EricLBuehler what are your thoughts?
To me it looks like we can do this by adding aici-abi as a dependency. (Note that the previous PR also added aici-runtime, which I'm no longer proposing to add.)
We could also start by copying the code and editing it. However, after reading it for a bit, I think using it as a dependency should be good enough...
I also published kalosm-sample as a separate library if you want to re-use the logic. It currently walks the DFA as it generates tokens, so each individual token takes a constant amount of time to validate regardless of the generation length. If you are generating a series of different but similar sequences, I would be happy to merge support for a prompt cache.
One technique kalosm uses that I haven't seen mentioned here is constraint-accelerated batching: if you see that the next n tokens must be a specific string, you can load all of those tokens into the KV cache at once in a larger batch to accelerate generation.
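A sketch of that idea, with hypothetical `GrammarState`/`Model` traits standing in for whatever mistral.rs ends up exposing:

```rust
/// Hypothetical traits used only for this illustration.
trait GrammarState {
    /// Returns the unique continuation if the grammar forces one.
    fn next_forced_string(&self) -> Option<String>;
}
trait Model {
    /// Runs one forward pass over several tokens, filling the KV cache.
    fn append_tokens(&mut self, tokens: &[u32]);
}

/// If the grammar only admits a single continuation, push all of its tokens
/// through the model in one batched step instead of decoding them one by one.
fn maybe_jump_forward(
    grammar: &impl GrammarState,
    model: &mut impl Model,
    tokenize: impl Fn(&str) -> Vec<u32>,
) {
    if let Some(forced) = grammar.next_forced_string() {
        let tokens = tokenize(&forced);
        model.append_tokens(&tokens);
    }
}
```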
@lucasavila00, I think it would be good to do something similar to the token trie + kalosm-style regex matching.
@ealmloff, thanks for the link! If you would be able to merge support for a prompt cache, that would be much appreciated! I was wondering, in the kalosm internals, do you apply the Parser to every token in the vocab to make a logit bias, or do you do something else?
> I was wondering, in the kalosm internals, do you apply the Parser to every token in the vocab to make a logit bias, or do you do something else?
Currently, yes, I do that here. If you use `top_k`, you could parse tokens only until you get at least k valid tokens, instead of parsing every logit.
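A sketch of that shortcut (a hypothetical helper, not the kalosm API): walk the tokens in descending logit order and stop once k grammar-valid candidates have been found.

```rust
/// Sketch: return up to `k` grammar-valid (token, logit) pairs, checking tokens
/// in descending logit order instead of validating the whole vocabulary.
fn top_k_valid(logits: &[f32], k: usize, is_valid: impl Fn(u32) -> bool) -> Vec<(u32, f32)> {
    let mut indexed: Vec<(u32, f32)> = logits
        .iter()
        .enumerate()
        .map(|(i, &l)| (i as u32, l))
        .collect();
    // Highest logit first.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));

    let mut out = Vec::with_capacity(k);
    for (tok, logit) in indexed {
        if is_valid(tok) {
            out.push((tok, logit));
            if out.len() == k {
                break;
            }
        }
    }
    out
}
```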
Implemented in #103, thank you @lucasavila00!
We will implement it based on this.
The idea is as follows, given a parsed BNF:
0) While the model is calculating the logits, prepare the logit bias on a worker thread (from a pool).
1) Run normal sampling first: if the returned token is valid under the grammar, avoid applying the logit bias.
2) While normal sampling is running, apply the logit bias on a worker thread (from a pool).
3) If normal sampling produced a token that would be invalid, rerun sampling with the logit bias applied.
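A minimal sketch of that flow with hypothetical helpers (in a real implementation the bias worker would be started as soon as the GPU begins the logit computation, and a thread pool would be used instead of spawning a thread per step):

```rust
/// Sketch of the dual-sampling flow: sample normally first, and only wait for
/// and apply the grammar logit bias if the sampled token violates the grammar.
fn sample_with_grammar(
    logits: Vec<f32>,
    is_allowed: impl Fn(u32) -> bool,                         // grammar check for one token
    compute_bias: impl FnOnce() -> Vec<f32> + Send + 'static, // 0.0 / -inf per token
    sample: impl Fn(&[f32]) -> u32,
) -> u32 {
    // 0) Prepare the logit bias on a worker.
    let bias_worker = std::thread::spawn(compute_bias);

    // 1) Normal sampling first.
    let token = sample(&logits);
    if is_allowed(token) {
        return token; // valid under the grammar: the bias is never applied
    }

    // 2)/3) Invalid: wait for the bias, apply it, and resample.
    let bias = bias_worker.join().expect("bias worker panicked");
    let biased: Vec<f32> = logits.iter().zip(&bias).map(|(l, b)| l + b).collect();
    sample(&biased)
}
```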