Dan-wanna-M / formatron

Formatron empowers everyone to control the format of language models' output with minimal overhead.

Efficient batched inference #3

Open Dan-wanna-M opened 3 months ago

Dan-wanna-M commented 3 months ago

While we support batched inference like other constrained decoding libraries, the current implementation can be parallelized further. In particular, we can mask logits in batches and run several kbnf engines in parallel.
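A minimal sketch of what that could look like: one scoped thread per sequence, each masking its own row of the batched logits. The `Engine` type and its `mask_logits` method are hypothetical stand-ins; the real kbnf API may use different names and signatures.

```rust
use std::thread;

// Hypothetical stand-in for a per-sequence kbnf engine; the real
// kbnf API may expose different names and signatures.
struct Engine;

impl Engine {
    // Sets the logits of grammar-disallowed tokens to -inf for this sequence.
    fn mask_logits(&mut self, logits: &mut [f32]) {
        // ... grammar-driven masking would go here ...
        let _ = logits;
    }
}

// Masks each row of the batched logits with its own engine,
// one scoped thread per sequence, instead of looping sequentially.
fn mask_batch(engines: &mut [Engine], logits: &mut [f32], vocab_size: usize) {
    thread::scope(|s| {
        for (engine, row) in engines.iter_mut().zip(logits.chunks_mut(vocab_size)) {
            s.spawn(move || engine.mask_logits(row));
        }
    });
}
```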

turboderp commented 3 months ago

I'm not sure what the control flow is, but at least ExLlamaV2 does multithreaded sampling for batches, so if you can disable the GIL during compute_allowed_tokens, that should go a long way.

Batching logits for processing would be problematic since ExLlamaV2 allows for concurrent jobs with completely different sampling/filtering settings.

Dan-wanna-M commented 3 months ago

> I'm not sure what the control flow is, but at least ExLlamaV2 does multithreaded sampling for batches, so if you can disable the GIL during compute_allowed_tokens, that should go a long way.

@turboderp Could you clarify what you mean by "disable the GIL during compute_allowed_tokens"? As far as I know, the Rust extension itself is not affected by the GIL (since it does not use Python types internally), and it is not possible to turn off the GIL for a Python method, is it?

turboderp commented 3 months ago

I'm not sure how it translates to Rust, but it's described here for (C++) extension functions.

Basically, wrap a block of code in Py_BEGIN_ALLOW_THREADS and Py_END_ALLOW_THREADS; as long as the block isn't manipulating any Python objects, other Python threads can continue running while the extension code does some blocking I/O or computation.

Dan-wanna-M commented 3 months ago

> I'm not sure how it translates to Rust, but it's described here for (C++) extension functions.
>
> Basically, wrap a block of code in Py_BEGIN_ALLOW_THREADS and Py_END_ALLOW_THREADS; as long as the block isn't manipulating any Python objects, other Python threads can continue running while the extension code does some blocking I/O or computation.

Got it, I need to check how to do this in pyo3.
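For reference, pyo3's equivalent is `Python::allow_threads`, which releases the GIL for the duration of a closure that must not touch Python objects. A minimal sketch; the function body is a placeholder (reusing the method name mentioned above purely for illustration), not formatron's actual computation:

```rust
use pyo3::prelude::*;

/// Releases the GIL while the pure-Rust computation runs, so other
/// Python threads can proceed, then reacquires it on return.
#[pyfunction]
fn compute_allowed_tokens(py: Python<'_>, token_ids: Vec<u32>) -> PyResult<Vec<u32>> {
    let allowed = py.allow_threads(|| {
        // Pure-Rust work only: no Python objects may be touched here.
        // Placeholder filter standing in for the real engine computation.
        token_ids
            .into_iter()
            .filter(|&t| t % 2 == 0)
            .collect::<Vec<u32>>()
    });
    Ok(allowed)
}
```

Everything inside the closure must be GIL-free, which matches the constraint described above for the C macros.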

Dan-wanna-M commented 3 months ago

https://github.com/vllm-project/vllm/issues/3567 is relevant as well. Hashing a large number of Python ints does not sound like a good idea either; maybe we can leverage the bitset to get an efficient bytes representation.
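A hedged sketch of that idea: pack the allowed-token set into 64-bit words and hash the packed words in one pass, instead of hashing each token id as a separate Python int. All names here are illustrative:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Packs allowed token ids into a fixed-size bitset: one bit per
// vocabulary entry, 64 entries per u64 word.
fn to_bitset(allowed: &[u32], vocab_size: usize) -> Vec<u64> {
    let mut words = vec![0u64; (vocab_size + 63) / 64];
    for &t in allowed {
        words[(t as usize) / 64] |= 1u64 << (t % 64);
    }
    words
}

// Hashes the whole packed representation in one pass instead of
// hashing every token id individually.
fn hash_bitset(words: &[u64]) -> u64 {
    let mut h = DefaultHasher::new();
    words.hash(&mut h);
    h.finish()
}
```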