noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

Performance Optimizations? #28

Closed: simon-mo closed this issue 4 months ago

simon-mo commented 7 months ago

Hi @noamgat,

First, thank you for contributing the logits processors PR to vLLM. I want to dive deeper into the performance of this library. Is there a way to speed things up by pre-computing the mask for each state-machine state, similar to outlines' approach?

(I also opened an issue asking them to leverage a tokenizer prefix-tree in the pre-computation pass: https://github.com/outlines-dev/outlines/issues/383)

I think there will be a lot of value in bringing the performance overhead as close to 0% as possible.
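(Illustrative sketch, not from the original thread: one way to picture the tokenizer prefix-tree idea is a character-level trie built once over the vocabulary, so that at a given parser state the allowed token ids can be gathered by walking the trie instead of scanning every token. The names TrieNode, build_token_trie, and collect_allowed are hypothetical, and the single allowed-character set is a simplification.)

from typing import Dict, List, Set

class TrieNode:
    def __init__(self):
        self.children: Dict[str, "TrieNode"] = {}
        self.token_ids: List[int] = []  # tokens whose decoded text ends exactly here

def build_token_trie(vocab: Dict[str, int]) -> TrieNode:
    # Built once per tokenizer; vocab maps token text -> token id.
    root = TrieNode()
    for token_text, token_id in vocab.items():
        node = root
        for ch in token_text:
            node = node.children.setdefault(ch, TrieNode())
        node.token_ids.append(token_id)
    return root

def collect_allowed(node: TrieNode, allowed_chars: Set[str]) -> List[int]:
    # Depth-first walk that only descends through characters the parser
    # would accept (simplified: one character set for all depths).
    allowed = list(node.token_ids)
    for ch, child in node.children.items():
        if ch in allowed_chars:
            allowed.extend(collect_allowed(child, allowed_chars))
    return allowed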

simon-mo commented 7 months ago

Additionally, I found that some overhead exists in the logits processor: moving the generated mask from CPU to GPU is pretty time-consuming. Doing it on every token slows things down 20-50%.
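(Illustrative sketch, not from the original thread: a small self-contained comparison of the two patterns under discussion, building the mask on CPU and copying it to the GPU every decode step versus reusing one preallocated GPU tensor and updating it in place. The vocabulary size, step count, and allowed-token set are placeholder assumptions.)

import time
import torch

VOCAB_SIZE = 32000              # assumed Llama-2 vocabulary size
STEPS = 200                     # number of simulated decode steps
device = "cuda"
allowed_tokens = torch.arange(0, 500)  # placeholder allowed token ids

# Pattern 1: allocate the mask on CPU and move it to the GPU every step.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(STEPS):
    mask = torch.full((VOCAB_SIZE,), float("-inf"))
    mask[allowed_tokens] = 0
    mask = mask.to(device)
torch.cuda.synchronize()
print(f"per-step CPU->GPU copy: {time.perf_counter() - start:.3f}s")

# Pattern 2: allocate the mask once on the GPU and update it in place.
gpu_allowed = allowed_tokens.to(device)
mask = torch.empty(VOCAB_SIZE, device=device)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(STEPS):
    mask.fill_(float("-inf"))
    mask[gpu_allowed] = 0
torch.cuda.synchronize()
print(f"in-place GPU update: {time.perf_counter() - start:.3f}s")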

noamgat commented 7 months ago

Hi! Thanks for taking an interest! I think the second point (CPU->GPU transfer) is the more important one; as for the first point, the calculation has to happen at some point anyway. How do you see that the CPU->GPU transfer is causing a slowdown? From my understanding, it would happen in the line

        mask[allowed_tokens] = 0

Is this the case?

GuyAglionby commented 7 months ago

I haven't profiled properly, but manually looking at sampled stack traces, it looks like a lot of time is spent on that line.

simon-mo commented 7 months ago

Yeah I think depending on the exact configuration:

GuyAglionby commented 7 months ago

I'm assuming it wouldn't work to implement your second option by changing vllm.py to define the mask in the constructor, put it on the GPU there, and just update it in __call__ as you describe? Sorry, I don't have a deep understanding of how vLLM works.

I ran a quick experiment with Llama 2 7B on a 40GB A100 with vLLM; it looks like using JsonSchemaParser slows generation from ~1720 to ~330 generated tok/s. Keen to help look into this.
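(Illustrative sketch, not from the original thread: roughly how such a throughput comparison could be set up with vLLM. The model name, prompts, and the placeholder logits_processor are assumptions; in practice the processor would be built via lm-format-enforcer's vLLM integration, whose exact builder call is omitted here rather than guessed.)

import time
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")
prompts = ["Generate a JSON object describing a person:"] * 8

def logits_processor(token_ids, logits):
    # Placeholder no-op processor; a real run would use the processor
    # produced by lm-format-enforcer's vLLM integration.
    return logits

def measure(sampling_params: SamplingParams) -> float:
    # Returns generated tokens per second for one batch of prompts.
    start = time.perf_counter()
    outputs = llm.generate(prompts, sampling_params)
    elapsed = time.perf_counter() - start
    generated = sum(len(o.outputs[0].token_ids) for o in outputs)
    return generated / elapsed

baseline = measure(SamplingParams(max_tokens=512))
constrained = measure(SamplingParams(max_tokens=512,
                                     logits_processors=[logits_processor]))
print(f"baseline: {baseline:.0f} tok/s, constrained: {constrained:.0f} tok/s")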

simon-mo commented 7 months ago

Exactly! I think it will have to look something like this:

import torch

class StatefulLogitsProcessor:
    def __init__(self, ...):
        # Allocate the mask once, on the GPU, instead of on every call
        self.mask = torch.zeros(..., device="cuda")

    def __call__(self, prev_tokens, curr_logits):
        ...

Currently, there's a GPU allocation at each decode step:

https://github.com/noamgat/lm-format-enforcer/blob/221b5f228cac6244b5d1a9cde7582107c602dafd/lmformatenforcer/integrations/vllm.py#L22C16-L22C16
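(Illustrative sketch, not the library's actual implementation: a slightly fuller version of the idea above, reusing one preallocated GPU mask per call. The class name is hypothetical and the allowed-token computation is stubbed out since it would come from the parser state.)

import torch

class PreallocatedMaskLogitsProcessor:
    """Hypothetical sketch: keep one GPU mask tensor and reuse it each step."""

    def __init__(self, vocab_size: int, device: str = "cuda"):
        self.mask = torch.empty(vocab_size, device=device)

    def _allowed_token_ids(self, prev_token_ids):
        # Stub: in the real library this would be derived from the parser state.
        raise NotImplementedError

    def __call__(self, prev_token_ids, logits: torch.Tensor) -> torch.Tensor:
        allowed = self._allowed_token_ids(prev_token_ids)
        self.mask.fill_(float("-inf"))   # reuse the preallocated tensor
        self.mask[allowed] = 0           # no new GPU allocation per step
        return logits + self.mask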

noamgat commented 7 months ago

Thanks for the feedback, everyone! I have a day set aside next week to work on lm-format-enforcer, and I will take care of some issues that need work. Hopefully I'll find an elegant way to work around this one as well.

noamgat commented 7 months ago

I released a version which contains a fix for this problem. Can you run with the latest version and check whether the performance improves for your use case?

GuyAglionby commented 7 months ago

Thanks! I will test and get back to you.

noamgat commented 4 months ago

Closing due to inactivity.