Dan-wanna-M / formatron

Formatron empowers everyone to control the format of language models' output with minimal overhead.
MIT License

ExLlamaV2 return types and multithreading #14

Closed · turboderp closed this issue 5 days ago

turboderp commented 2 weeks ago

I'll just put this in an issue to avoid getting things too mixed up. I've added list as a valid return type from ExLlamaV2Filter.next() in the dev branch, and it does indeed reduce the overhead quite a bit.

    # Old version for compatibility
    def next_set(self) -> typing.Tuple[typing.Set[int], typing.Set[int]]:
        if self._formatter.is_completed():
            return {self.tokenizer.eos_token_id}, set()
        self._formatter.compute_allowed_tokens()
        self._pass_tokens.clear()
        self._pass_tokens.update(self._formatter.get_allowed_tokens_since_last_computation())
        return self._pass_tokens, set()

    def next(self) -> typing.Tuple[typing.Sequence[int], typing.Sequence[int]]:
        # Kludge to maintain compatibility with exllamav2 <= 0.2.0, which does
        # not set allow_return_type_list on filters whose next() may return lists
        if not hasattr(self, "allow_return_type_list"):
            return self.next_set()
        if self._formatter.is_completed():
            return [self.tokenizer.eos_token_id], []
        self._formatter.compute_allowed_tokens()
        return self._formatter.get_allowed_tokens_since_last_computation(), []

Aside from the ugly check to avoid breaking existing versions of ExLlama, it should all be transparent from Formatron's side. If Formatron returns a list and the sampler needs a set (because it's running multiple filters or whatever), it will convert the list itself.
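To illustrate the idea (this isn't the actual ExLlamaV2 sampler code, just a sketch of the conversion): a single filter's list can be used as-is, and conversion to sets only happens when several filters have to be intersected.

    # Sketch only -- not the actual ExLlamaV2 sampler code.
    # filter_results is assumed to be a list of (allowed_tokens, end_tokens) pairs,
    # one per filter, where allowed_tokens may be a list or a set.
    def combine_allowed_tokens(filter_results):
        if len(filter_results) == 1:
            allowed, _ = filter_results[0]
            return allowed                  # a list can be used directly
        combined = None
        for allowed, _ in filter_results:
            s = set(allowed)                # convert to a set only when intersecting
            combined = s if combined is None else combined & s
        return combined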

I would PR the above, but I'm working on a somewhat broken branch of Formatron, as I couldn't get my test code to work after the refactor yesterday (JsonGenerator disappeared, and I'm not sure what the interface is at the moment, since the examples aren't updated yet).

Anyway, I then went on to add asynchronous filter evaluation: each filter in a batch (i.e. all the logic within next()) runs in its own thread, starting immediately after the CUDA queue has been built for the following forward pass (there's a rough sketch of the pattern after the table). Here's the resulting overhead (reduction in tokens/second from Formatron, L3-8B 4bpw on a 4090):

bsz  Original  Threads  List    Threads+List
1    8.24%     5.63%    3.22%   2.38%
2    5.44%     6.16%    5.23%   3.73%
4    10.52%    1.01%    6.53%   0.18%
8    13.65%    4.50%    3.60%   -1.41%

(Overhead does actually become negative in some situations, because sampling is skipped whenever the filter only allows one token, making constrained sampling potentially faster at times.)
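The threading pattern described above, in very rough sketch form (not the actual exllamav2 generator code; the executor and the `filters` list are stand-ins):

    # Rough sketch of the overlap, not the actual exllamav2 generator code.
    # `filters` is one filter object per sequence in the batch; the forward pass
    # is assumed to have already been enqueued on the CUDA stream, so the CPU
    # can spend that time computing allowed tokens.
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor()  # shared pool, created once

    def evaluate_filters_async(filters):
        # Kick off next() for every sequence right after the forward pass is queued
        futures = [executor.submit(f.next) for f in filters]
        # Join just before sampling, when the allowed-token lists are actually needed
        return [fut.result() for fut in futures]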

It's worth noting that I'm testing with some very basic filters. I would expect a more dramatic difference from multithreading for more complex JSON schemas etc.

Since there's a little bit of overhead from launching a thread, filters have to opt in to the background worker, so you should add this method to the filter class as well:

    def use_background_worker(self) -> bool:
        # Opt in to having next() evaluated on a background worker thread
        return True
turboderp commented 2 weeks ago

Oh shoot, I hadn't updated KBNF after the no_gil stuff. It helps:

bsz  Original  Threads  List    Threads+List  T+L+upd. KBNF
1    8.24%     5.63%    3.22%   2.38%         1.98%
2    5.44%     6.16%    5.23%   3.73%         2.65%
4    10.52%    1.01%    6.53%   0.18%         -0.76%
8    13.65%    4.50%    3.60%   -1.41%        -2.17%
Dan-wanna-M commented 2 weeks ago

> I would PR the above, but I'm working on a somewhat broken branch of Formatron, as I couldn't get my test code to work after the refactor yesterday (JsonGenerator disappeared, and I'm not sure what the interface is at the moment, since the examples aren't updated yet).

I planned to update my README examples (and the release notes, which help with migration) once Formatron 0.4.0 is published. In the meantime, you can take a look at this file to check out the newest interface.

That said, your modified integration itself should work regardless of the refactor, since the refactor is mostly about making the logic and usage of converting a Schema into a Formatter cleaner. You may want to use Formatron v0.3.3 (which does not break API compatibility) for performance testing, though, since this commit fixed an important overhead issue where the exllamav2 generator was forced to output eos_token constantly until max_new_tokens.
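For reference, the post-refactor usage looks roughly like this (a sketch only; module paths and names may still shift before 0.4.0 lands, and model/tokenizer are whatever your exllamav2 setup already provides):

    # Sketch of the post-refactor interface; exact names may differ in 0.4.0.
    from formatron.formatter import FormatterBuilder
    from formatron.schemas.pydantic import ClassSchema
    from formatron.integrations.exllamav2 import create_formatter_filter

    class Goods(ClassSchema):
        name: str
        price: float

    f = FormatterBuilder()
    f.append_line(f"{f.json(Goods, capture_name='json')}")
    # The builder, rather than a JsonGenerator, is what gets turned into a filter;
    # model and tokenizer are your existing ExLlamaV2 objects.
    exl_filter = create_formatter_filter(model, tokenizer, f)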

Dan-wanna-M commented 5 days ago

Incorporated in v0.4.0.