eth-sri / lmql

A language for constraint-guided and efficient LLM programming.
https://lmql.ai
Apache License 2.0

[BUG] sub-queries fail with temperature=0 #328

Open gamendez98 opened 4 months ago

gamendez98 commented 4 months ago

I am running this code:

import lmql

class Output:
   def __init__(self):
      self.output_vars = {}

   def add(self, vvar, value):
      if vvar in self.output_vars:
         self.output_vars[vvar].append(value)
      else:
         self.output_vars[vvar] = [value]

   def add_all(self, vvar, values):
      if vvar in self.output_vars:
         self.output_vars[vvar].extend(values)
      else:
         self.output_vars[vvar] = list(values)

self = Output()

@lmql.query
async def make_summary_option(self, option_text, length_limit, option_type):
   '''
   "{:user}Make a {option_text}"
   "{:assistant}[option]" where STOPS_AT(option, nl) and STOPS_AT(option, ".")
   if len(option) > length_limit:
      "{:user}The {option_text} is too long, make it shorter"
      "{:assistant}[option]" where STOPS_AT(option, nl) and STOPS_AT(option, ".")
   self.add(option_type, option.strip())
   '''

"""{:system}The user wants to make an activity for the instructions `Select the sentences that are true in the text`.
To do it they will ask you for `correct answers` or `distractors` for a particular text.
Here is the text:

{material_text}

If you are asked for a correct answer more than once make them different.
Only answer with the answer or distractor, nothing more.
Your responses should be shorter than 100 characters.
Avoid repetitive language.
"""
length_limit = 120
for i in range(n_summaries):
   '[correct_summaries: make_summary_option(self, "correct answer", length_limit, "correct_summaries")]'
for i in range(n_distractors):
   '[incorrect_summaries: make_summary_option(self, "distractor", length_limit, "incorrect_summaries")]'

and I get the following error:

Traceback (most recent call last):
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3546, in run_code
    await eval(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-d00d73af9de3>", line 1, in <module>
    await tg.generate(material_text=text, n_summaries=2, n_distractors=4)
  File "/home/gustavom/Documents/slang-nlp/main/text_generation/text_generator.py", line 42, in generate
    result = await self.prompt(*(prompt_parameters[name] for name in self.prompt_arguments))
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 230, in __acall__
    results = await interpreter.run(self.fct, **query_kwargs)
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/tracing/tracer.py", line 240, in wrapper
    return await fct(*args, **kwargs)
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 1070, in run
    async for _ in decoder_fct(prompt, **decoder_args):
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/dclib/decoders.py", line 39, in sample
    h = await model.rewrite(h, noscore=True)
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_rewrite.py", line 194, in rewrite
    result_items = await asyncio.gather(*[op_rewrite(path, seqs) for path, seqs in ar.sequences.items()])
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_rewrite.py", line 191, in op_rewrite
    return path, await self._rewrite_seq(seqs, noscore=noscore)
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_rewrite.py", line 64, in _rewrite_seq
    rewritten_ids = await rewriter(seqs, mask)
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 905, in rewrite_processor
    results = await asyncio.gather(*[self.rewrite_for_sequence(s, needs_rewrite) for s,needs_rewrite in zip(seqs, mask_seq_to_rewrite)])
  File "/home/gustavom/Documents/slang-nlp/venv/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 773, in rewrite_for_sequence
    if result_state.query_head.result is not None:
AttributeError: 'NoneType' object has no attribute 'query_head'

It only happens if I call the subquery more than once, that is, in this case, when n_summaries + n_distractors >= 2.

lbeurerkellner commented 3 months ago

Thanks for reporting this. This seems to be a bug. I have investigated this for a bit now, and one workaround that could work is to add a space in front of every summary variable, i.e. ' [incorrect_summaries: ...]' instead of '[incorrect_summaries: ...]'. The issue occurs with back-to-back variables and nested queries.
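
Applied to the query in the report, the suggested workaround amounts to editing the two prompt statements inside the loops so that each one starts with a space. The sketch below is plain Python, not LMQL, and only shows the before/after text of those statements. `add_leading_space` is a hypothetical helper introduced for illustration; it is not part of LMQL, and in the actual query the string literals would simply be edited by hand.

```python
def add_leading_space(statement: str) -> str:
    """Return the prompt statement with one space before a leading [variable].

    Hypothetical helper for illustration only; it is not part of LMQL.
    """
    if statement.startswith("["):
        return " " + statement
    return statement


# The two prompt statements from the loops in the original report.
original_statements = [
    '[correct_summaries: make_summary_option(self, "correct answer", length_limit, "correct_summaries")]',
    '[incorrect_summaries: make_summary_option(self, "distractor", length_limit, "incorrect_summaries")]',
]

# The same statements with the workaround applied.
patched_statements = [add_leading_space(s) for s in original_statements]

for before, after in zip(original_statements, patched_statements):
    print("before:", repr(before))
    print("after: ", repr(after))
```

The thread does not confirm that this avoids the error in every case; it is described only as a workaround that could work.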