noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

Generation is too slow when using RegexParser(r'.*') or RegexParser(r'.+') #84

Open ZhouGongZaiShi opened 6 months ago

ZhouGongZaiShi commented 6 months ago
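from lmformatenforcer import RegexParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor

output_regex = r"\d*.*"
parser = RegexParser(output_regex)
logits_processor = build_vllm_logits_processor(llm_tokenizer_data, parser, analyze=False)
sampling_params.logits_processors = [logits_processor]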

On an A100 GPU, I use vLLM to load a local model together with lm-format-enforcer.

With output_regex = r"\d*\w*", generation runs at about 200 tokens per second. With output_regex = r"\d*.*", it runs at about 5 tokens per second.

The speed drops by a factor of forty. Sometimes "." is necessary in a regular expression.

noamgat commented 6 months ago

Hi, I was not able to reproduce this issue on my computer. Can you check if this slowdown also occurs if you use the same logits processor instance for multiple requests?

RegexParser is implemented behind the scenes as a state machine. The first time a state is encountered, all of its legal tokens are calculated, which can take a bit of time. After that, the result is reused and not recalculated. The regex you gave as an example has only 3 states in its FSM, so after the first generation, subsequent ones should see negligible overhead.

Can you check if this is the case?
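
To make the caching behaviour described above concrete, here is a minimal sketch. It is not the library's actual implementation: the hand-built two-state DFA for \d*\w*, the `CachedTokenFilter` class, and the toy vocabulary are all hypothetical names invented for illustration. It only shows the idea that the expensive scan over the vocabulary happens once per state, provided the same instance is reused across requests.

```python
from typing import Dict, FrozenSet, List, Optional


def step(state: int, ch: str) -> Optional[int]:
    """Toy character-level DFA for \\d*\\w* (hypothetical, hand-built).

    State 0: still inside the \\d* part. State 1: inside the \\w* part.
    Returns the next state, or None if `ch` is not allowed.
    """
    if state == 0 and ch.isdigit():
        return 0
    if ch.isalnum() or ch == "_":
        return 1
    return None


class CachedTokenFilter:
    """Computes the allowed token ids for a state once, then reuses them."""

    def __init__(self, vocab: List[str]) -> None:
        self.vocab = vocab
        self.cache: Dict[int, FrozenSet[int]] = {}
        self.scans = 0  # number of full vocabulary scans performed

    def _walk(self, state: int, token: str) -> Optional[int]:
        # Feed the token's characters through the DFA one by one.
        current: Optional[int] = state
        for ch in token:
            current = step(current, ch)
            if current is None:
                return None
        return current

    def allowed_tokens(self, state: int) -> FrozenSet[int]:
        if state not in self.cache:
            # Expensive path: scan the whole vocabulary for this state.
            self.scans += 1
            self.cache[state] = frozenset(
                i for i, tok in enumerate(self.vocab)
                if self._walk(state, tok) is not None
            )
        return self.cache[state]


vocab = ["1", "23", "a", "ab", "_x", " ", "!", "4b"]
filt = CachedTokenFilter(vocab)

# Three simulated "requests" sharing ONE filter instance.
for _ in range(3):
    filt.allowed_tokens(0)
    filt.allowed_tokens(1)

print(filt.scans)  # 2: one scan per state, regardless of request count
```

If a new `CachedTokenFilter` were created for every request, the cache would start empty each time and the scan cost would be paid again, which is why reusing the same instance matters in this sketch.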

goutham794 commented 5 months ago

Hello,

I'm trying to use regex-controlled generation. With a huge batch, processing does not even start; it worked when I gave it a single example, albeit slowly. My initial regex matched a valid Python list of strings. I simplified it, but it still does not even start:

from vllm import SamplingParams
from lmformatenforcer import RegexParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor

list_regex = r'\[.*\]'
parser = RegexParser(list_regex)

logits_processor = build_vllm_logits_processor(tokenizer_data, parser)

sampling_params = SamplingParams(max_tokens=100, logits_processors=[logits_processor])
results = llm.generate([p['text'] for p in dataset], sampling_params=sampling_params)

goutham794 commented 5 months ago

It started after a long time, and the throughput is much lower: 2.5 it/s, compared to 110 it/s without the logits processor.
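
One way to check whether the slowdown is a one-time warm-up cost or persists across requests is to time several rounds separately while reusing the same logits processor. The sketch below is a generic helper under stated assumptions: `measure_throughput` and `fake_generate` are hypothetical names, and `fake_generate` is only a stand-in so the example runs without a model. In practice the callable would wrap the real generation call and return the generated token ids per prompt.

```python
import time
from typing import Callable, List, Sequence


def measure_throughput(
    generate: Callable[[Sequence[str]], List[List[int]]],
    prompts: Sequence[str],
    rounds: int = 3,
) -> List[float]:
    """Return tokens/second for each round.

    Round 0 includes any one-time warm-up cost (for example, caches being
    filled); later rounds show steady-state throughput.
    """
    rates: List[float] = []
    for _ in range(rounds):
        start = time.perf_counter()
        outputs = generate(prompts)
        elapsed = time.perf_counter() - start
        n_tokens = sum(len(token_ids) for token_ids in outputs)
        # Guard against a zero reading from a very fast stand-in function.
        rates.append(n_tokens / elapsed if elapsed > 0 else float("inf"))
    return rates


def fake_generate(prompts: Sequence[str]) -> List[List[int]]:
    """Hypothetical stand-in for a real model call: 10 token ids per prompt."""
    return [[0] * 10 for _ in prompts]


rates = measure_throughput(fake_generate, ["prompt a", "prompt b"], rounds=3)
for i, rate in enumerate(rates):
    print(f"round {i}: {rate:.1f} tokens/s")
```

If the first round is slow and later rounds recover, the cost is dominated by warm-up; if all rounds are equally slow, the overhead is per-token and caching is not helping.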