eth-sri / lmql

A language for constraint-guided and efficient LLM programming.
https://lmql.ai
Apache License 2.0

STOPS_BEFORE on multiple tokens retains the first token #57

Open JasperDekoninck opened 1 year ago

JasperDekoninck commented 1 year ago

The following query returns "What did the fish say when" instead of the expected "What did the fish say".

argmax(max_len=80)
   """A list of good dad jokes. A indicates the punchline
   Q: How does a penguin build its house?
   A: Igloos it together.
   Q: Which knight invented King Arthur's Round Table?
   A: Sir Cumference.
   Q:[JOKE]"""
from
   "openai/text-davinci-003"
where
   STOPS_BEFORE(JOKE, "when it hit")

lbeurerkellner commented 1 year ago

This seems to be a bug with the OpenAI API and its "stop" parameter. The API documentation specifies that stop sequences will be removed from the response, but in this case only "it hit" is removed.

As a fix, we can disable LMQL's use of this parameter. Then the query works as intended:

argmax(max_len=80, chatty_openai=True, openai_nonstop=True)
   """A list of good dad jokes. A indicates the punchline
   Q: How does a penguin build its house?
   A: Igloos it together.
   Q: Which knight invented King Arthur's Round Table?
   A: Sir Cumference.
   Q:[JOKE]"""
from
   "openai/text-davinci-003"
where
   STOPS_BEFORE(JOKE, "when it hit")

# OUTPUT: What did the fish say

I currently see no way of fixing this: we cannot tell whether the stop phrase was truncated correctly, because the original sequence may also have ended in "when when it hit". Using the openai_nonstop=True option is the only workaround for now.
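
For illustration, here is a small sketch of that ambiguity (the completions below are made up, not real API output): if the API strips the stop phrase only partially, a completion that genuinely contains a double "when" and one that does not can come back as exactly the same text.

stop = "when it hit"

raw_a = "What did the fish say when when it hit the bus"  # genuinely contains an extra "when"
raw_b = "What did the fish say when it hit the bus"       # does not

text_a = raw_a[:raw_a.index(stop)]      # correct truncation:       "What did the fish say when "
text_b = raw_b[:raw_b.index("it hit")]  # buggy partial truncation: "What did the fish say when "

assert text_a == text_b  # from the response text alone, the two cases are indistinguishable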

JasperDekoninck commented 1 year ago

I am not sure that is the case: when I run the query against the OpenAI API directly, the text field does not contain "when". However, the logprobs still list " when" as a token, so maybe LMQL uses those to append it to the prompt? Below is the raw output I get from OpenAI:

{'object': 'text_completion',
 'created': 1683320740,
 'model': 'text-davinci-003',
 'choices': [{'text': ' What did the fish say ',
   'index': 0,
   'logprobs': {'tokens': [' What', ' did', ' the', ' fish', ' say', ' when'],
    'token_logprobs': [-0.24742532,
     -0.21822312,
     -0.15329956,
     -0.6100834,
     -0.00090957596,
     -0.00039830978],
    'top_logprobs': [{' What': -0.24742532},
     {' did': -0.21822312},
     {' the': -0.15329956},
     {' fish': -0.6100834},
     {' say': -0.00090957596},
     {' when': -0.00039830978}],
    'text_offset': [214, 219, 223, 227, 232, 236]},
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 71, 'completion_tokens': 6, 'total_tokens': 77}}

I also notice that when I run the query against the API with the stop word "when" instead, the raw output is exactly the same (also in streaming form), so it is more likely a problem on the LMQL side?
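
As a quick sanity check (a sketch, assuming the response above is loaded into a Python dict named response; the name is just for illustration), the mismatch between "text" and the detokenized "tokens" shows up directly:

choice = response["choices"][0]
text = choice["text"]                  # ' What did the fish say '
tokens = choice["logprobs"]["tokens"]  # [' What', ' did', ' the', ' fish', ' say', ' when']

joined = "".join(tokens)               # ' What did the fish say when'
print(joined == text)                  # False: "tokens" contains one more token than "text"
print(joined.startswith(text))         # True: the extra ' when' begins exactly where "text" was cut off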

lbeurerkellner commented 1 year ago

Good observation. So we can actually fix this after all: we just need to make sure not to consume "tokens" beyond what the "text" return value contains.

I suspect this will also cover the case where the last produced token is detokenized and then truncated. E.g. "tokens" may include ",\n" as a token, while "text" already ends before the "\n", leading to a ","-terminated string. In these cases we need to retokenize "text" to make sure a truncated (sub)token is replaced by its truncated version, i.e. return "," as the last token and not the combined ",\n".
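
As a rough sketch of the truncation logic this would need (not necessarily how it should be implemented in LMQL): keep only the tokens that are fully covered by "text", and if the last token is cut in half by the stop handling, keep (and later retokenize) its surviving prefix.

def trim_tokens_to_text(tokens, text):
    kept, consumed = [], 0
    for tok in tokens:
        if consumed + len(tok) <= len(text):
            kept.append(tok)               # token fully contained in "text"
            consumed += len(tok)
        else:
            remainder = text[consumed:]
            if remainder:                  # token was cut in half, e.g. ",\n" -> ","
                kept.append(remainder)     # this prefix would have to be re-tokenized
            break
    return kept

tokens = [' What', ' did', ' the', ' fish', ' say', ' when']
text = ' What did the fish say '
print(trim_tokens_to_text(tokens, text))
# [' What', ' did', ' the', ' fish', ' say', ' ']  -- the stripped ' when' is no longer consumed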

JasperDekoninck commented 1 year ago

You are right. The following also reproduces the bug:

argmax
   """A list of good dad jokes. A indicates the punchline
   Q: How does a penguin build its house?
   A: Igloos it together.
   Q: Which knight invented King Arthur's Round Table?
   A: Sir Cumference.
   Q:[JOKE]"""
from
   "openai/text-davinci-003"
where
   len(JOKE) < 120 and 
   STOPS_BEFORE(JOKE, "id ")

Note the space at the end of "id ": it is necessary for the bug to appear, since the " " belongs to the next token. If we just use "id", LMQL works as expected.