maximzubkov opened this issue 7 months ago
Good find! I just merged the PR. You are correct in diagnosing the issue. The right fix would be cloning the logits processor for every sequence in the group. Contribution welcomed indeed!
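To illustrate the suggested fix, here is a minimal sketch. `StatefulProcessor` is a hypothetical stand-in (not vLLM code) for a stateful logits processor such as `CFGLogitsProcessor`: it records every token it sees in internal state, so sharing one instance across sequences corrupts that state, while cloning a prototype per sequence keeps each state private.

```python
import copy

# Hypothetical stand-in for a stateful logits processor such as
# CFGLogitsProcessor: it records every token it sees in internal state.
class StatefulProcessor:
    def __init__(self):
        self.generated = []  # internal, FSM-like state

    def __call__(self, token_id, logits):
        self.generated.append(token_id)  # state advances on every call
        return logits

# One instance shared by two sequences in a group: their tokens
# interleave in the single shared state, corrupting it for both.
shared = StatefulProcessor()
for _ in range(3):
    for seq_token in (0, 1):  # sequence 0 emits 0s, sequence 1 emits 1s
        shared(seq_token, [0.0])
assert shared.generated == [0, 1, 0, 1, 0, 1]

# The fix: clone the prototype once per sequence, so every sequence
# advances its own private copy of the state.
prototype = StatefulProcessor()
clones = [copy.deepcopy(prototype) for _ in range(2)]
for _ in range(3):
    for seq_token, proc in zip((0, 1), clones):
        proc(seq_token, [0.0])
assert clones[0].generated == [0, 0, 0]
assert clones[1].generated == [1, 1, 1]
```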
Hello, @simon-mo! Thank you for the prompt response. I implemented the fix and left some comments in the code to explain why I made certain design decisions; see the following PR.
I also slightly updated the reproduction script to cover the case where multiple requests are sent to the engine via [prompt, prompt]
(this also used to fail with the same bug when I tested it):
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.model_executor.guided_logits_processors import CFGLogitsProcessor
model = "microsoft/phi_1"
prompt = "Writa a simple SQL query to the table table_2 checking if col_1 equals to 1"
tokenizer = AutoTokenizer.from_pretrained(model)
simple_sql_grammar = """
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    n=10,
    max_tokens=512,
    logits_processors=[CFGLogitsProcessor(simple_sql_grammar, tokenizer)],
)
llm = LLM(model=model, dtype="auto")
outputs = llm.generate([prompt, prompt], sampling_params)
print([
    output_.text for output_ in outputs[0].outputs
])
Your current environment
I used Docker on a server with 4x NVIDIA RTX A4000 GPUs.
🐛 Describe the bug
I tested the Context Free Grammar feature with vLLM, asking phi-1 to generate a simple SQL query following this test from a recent PR. Although the CFGLogitsProcessor feature is not merged yet (the PR is still open), the example above worked fine when I used n=1 in SamplingParams. However, when I switched to n=10, my code failed with:
Diving deeper into the code, I figured that this bug would also occur with RegexLogitsProcessor and JSONLogitsProcessor due to the current implementation of _apply_logits_processors. RegexLogitsProcessor, JSONLogitsProcessor, and CFGLogitsProcessor all call self.fsm.allowed_token_ids for every sequence considered by the beam search, and because a single processor instance is shared, its state cache is updated incorrectly: as you can see from the bug output, the cache is shared between all 10 beams, and every token predicted by every beam gets added to it (SSSSSSSSSSELELELELELELELELELELECT). So maybe it would make sense to change _apply_logits_processors to let every beam have its own processor, e.g.:
Looking forward to your response, and I would be happy to implement the changes under your guidance!