Closed simon-mo closed 4 months ago
Additionally, I found some overhead in the logits processors: moving the generated mask from CPU to GPU is pretty time consuming. Doing it on every token slows things down 20-50%.
Hi! Thanks for taking interest! I think the second point (CPU->GPU transfer) is more important, as for the first point, the calculation has to happen sometimes. How do you see that the CPU->GPU transfer is causing a slowdown? From my understanding, it would happen in the line
mask[allowed_tokens] = 0
Is this the case?
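For context, here is a minimal sketch (with a hypothetical function name and an assumed vocabulary size) of the pattern being discussed: a fresh mask tensor is allocated each step, and the allowed-token indices live on the host, so the assignment triggers a host-to-device transfer on every decoded token.

```python
import torch

# Use the GPU when available so the sketch also runs on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 32000  # assumed; Llama-2's vocabulary size

def naive_mask(allowed_tokens: list[int]) -> torch.Tensor:
    # A fresh tensor is allocated every call, and `allowed_tokens` is a
    # Python list on the host, so this indexing implies a CPU->GPU copy
    # (and synchronization) on every generated token.
    mask = torch.full((VOCAB_SIZE,), float("-inf"), device=device)
    mask[allowed_tokens] = 0
    return mask

logits = torch.randn(VOCAB_SIZE, device=device)
masked = logits + naive_mask([1, 5, 7])
```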
I haven't properly profiled, but manually looking at sampled stack traces, it looks like that line has a lot of time spent on it.
Yeah, I think it depends on the exact configuration:
I'm assuming it wouldn't work to implement your second option by changing vllm.py to define mask
in the constructor, move it to the GPU there, and just update it in __call__ as you describe?
Sorry, I don't have a deep understanding of how vLLM works.
I ran a quick experiment with Llama2 7B on a 40GB A100 with vLLM; it looks like using JsonSchemaParser
slows generation from ~1720 to ~330 tok/s. Keen to help look into this.
Exactly! I think it will have to look something like this:

```python
class StatefulLogitsProcessor:
    def __init__(self, ...):
        # Allocate the mask once, directly on the GPU
        self.mask = torch.zeros(..., device="cuda")

    def __call__(self, prev_tokens, curr_logits):
        ...
```

Currently, there's a GPU malloc on each decode step.
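To make the sketch above concrete, here is a hypothetical fuller version (the `allowed_tokens` parameter and class shape are assumptions for illustration, not the actual lm-format-enforcer or vLLM interface): the mask tensor is allocated once in the constructor and then reset and updated in place, so no new GPU allocation happens per token.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

class StatefulLogitsProcessor:
    """Reuses one preallocated mask tensor instead of allocating per token."""

    def __init__(self, vocab_size: int):
        # One allocation for the lifetime of the request.
        self.mask = torch.full((vocab_size,), float("-inf"), device=device)

    def __call__(self, prev_tokens, curr_logits, allowed_tokens):
        # Reset and update the existing tensor in place: no fresh GPU malloc.
        self.mask.fill_(float("-inf"))
        self.mask[allowed_tokens] = 0
        return curr_logits + self.mask
```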
Thanks for the feedback everyone! I have a day set aside next week to work on lm-format-enforcer, and I will take care of some issues that need work. Hopefully I'll find an elegant way to work around this one as well.
I released a version which contains a fix to this problem. Can you run using the latest version and check if the performance improves for your use case?
Thanks! I will test and get back to you
Closing due to inactivity.
Hi @noamgat,
First, thank you for contributing to vLLM for the logits processors PR. I want to dive deeper into the performance of this library. Is there a way to speed things up by pre-computation of the mask per state machine similar to the outline's approach?
(I also opened an issue asking them to leverage tokenizer prefix-tree in pre-computation pass: https://github.com/outlines-dev/outlines/issues/383)
I think there will be a lot of value in bringing the performance overhead as close to 0% as possible.
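As a sketch of what per-state precomputation might look like (the function name and data layout here are hypothetical; Outlines builds an index from FSM states to allowed token sets): the allowed-token masks for all states are built once and moved to the GPU in a single transfer, so applying a mask at decode time is just an index into a resident tensor.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def build_state_masks(allowed_per_state: dict[int, list[int]],
                      vocab_size: int) -> torch.Tensor:
    # One-time pass: a [num_states, vocab_size] additive mask,
    # transferred to the GPU once instead of once per token.
    masks = torch.full((len(allowed_per_state), vocab_size), float("-inf"))
    for state, tokens in allowed_per_state.items():
        masks[state, tokens] = 0
    return masks.to(device)

# At decode time, applying the mask for the current FSM state
# is a single indexed add with no host-device traffic:
state_masks = build_state_masks({0: [1, 2], 1: [3]}, vocab_size=8)
logits = torch.zeros(8, device=device)
masked = logits + state_masks[0]
```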