eth-sri / lmql

A language for constraint-guided and efficient LLM programming.
https://lmql.ai
Apache License 2.0

Dataclass has trouble parsing JSON string fields with newlines as well as double-quotes #244

Open KristianMischke opened 11 months ago

KristianMischke commented 11 months ago

I've been loving the dataclass parsing features, but I ran into an issue where the library errors when parsing string values that contain double quotes or newline characters.

This is my dataclass and the query that is used:

from dataclasses import dataclass

import lmql

@dataclass
class CardData:
    name: str
    mana_cost: str
    supertypes: str
    types: str
    subtypes: str
    rules: str
    flavor: str
    # Creature
    attack: str
    defense: str
    # Planeswalker
    loyalty: str

@lmql.query
def convert_to_object(card_text):
    """lmql
    "{card_text}\n"
    "Structured: [CARD_DATA]\n" where type(CARD_DATA) is CardData

    CARD_DATA
    """

This is an example culprit value for card_text:

Yktlash, the Unseen Empress

{3}{B}{B}{G}{G}
Legendary Creature — Elf Shaman (5/5)
Trample

When Yktlash, the Unseen Empress enters the battlefield, create three 1/1 black and green Elf Warrior creature tokens.

{B}{G}, Sacrifice an Elf: Target creature gets -3/-3 until end of turn.

> 'In the depths of the forest, her reign remains elusive. Only the echoes of her whispers reveal the true power she wields.'

This was printed to the console. You can see the newlines in the rules field, which are likely the issue.

Failed to parse JSON from {"name":"Yktlash, the Unseen Empress","mana_cost":"{3}{B}{B}{G}{G}","supertypes":"Legendary","types":"Creature","subtypes":"Elf Shaman","rules":"Trample
When Yktlash, the Unseen Empress enters the battlefield, create three 1/1 black and green Elf Warrior creature tokens.
{B}{G}, Sacrifice an Elf: Target creature gets -3/-3 until end of turn.","flavor":"In the depths of the forest, her reign remains elusive. Only the echoes of her whispers reveal the true power she wields.","attack":"5","defense":"5","loyalty":"0"}
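
For context, raw newlines inside a JSON string are rejected by Python's standard `json` module in its default strict mode. The following is a minimal sketch of that behavior using only the standard library; the shortened `rules` text is illustrative, and nothing here claims to show how LMQL itself parses the output.

```python
import json

# The \n below is a real newline character inside the JSON string value,
# which the default (strict) decoder treats as an invalid control character.
raw = '{"rules": "Trample\nWhen Yktlash enters the battlefield..."}'

try:
    json.loads(raw)
    strict_ok = True
except json.JSONDecodeError as exc:
    strict_ok = False
    print(f"strict parse failed: {exc.msg}")

# strict=False tells the decoder to accept control characters inside strings.
lenient = json.loads(raw, strict=False)

# json.dumps escapes newlines and inner double quotes, so its output round-trips.
escaped = json.dumps({"rules": 'Trample\nShe said "run".'})
round_trip = json.loads(escaped)
```

Whether lenient parsing or escaping on output is the better fix depends on where the string is produced; the sketch only shows that both options exist in the standard library.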

Then this exception was thrown:

Traceback (most recent call last):
  File "D:\REDACTED\src\main.py", line 98, in <module>
    result: LMQLResult = convert_to_object(card_rules)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\api\queries.py", line 148, in lmql_query_wrapper
    return module.query(*args, **kwargs)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\lmql_runtime.py", line 204, in __call__
    return call_sync(self, *args, **kwargs)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\loop.py", line 37, in call_sync
    res = loop.run_until_complete(task)
  File "C:\Users\krist\AppData\Local\Programs\Python\Python310\lib\asyncio\base_events.py", line 641, in run_until_complete
    return future.result()
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\lmql_runtime.py", line 230, in __acall__
    results = await interpreter.run(self.fct, **query_kwargs)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\tracing\tracer.py", line 240, in wrapper
    return await fct(*args, **kwargs)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\interpreter.py", line 1070, in run
    async for _ in decoder_fct(prompt, **decoder_args):
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\dclib\decoders.py", line 22, in argmax
    h = await model.rewrite(h, noscore=True)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\dclib\dclib_rewrite.py", line 194, in rewrite
    result_items = await asyncio.gather(*[op_rewrite(path, seqs) for path, seqs in ar.sequences.items()])
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\dclib\dclib_rewrite.py", line 191, in op_rewrite
    return path, await self._rewrite_seq(seqs, noscore=noscore)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\dclib\dclib_rewrite.py", line 64, in _rewrite_seq
    rewritten_ids = await rewriter(seqs, mask)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\interpreter.py", line 905, in rewrite_processor
    results = await asyncio.gather(*[self.rewrite_for_sequence(s, needs_rewrite) for s,needs_rewrite in zip(seqs, mask_seq_to_rewrite)])
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\interpreter.py", line 719, in rewrite_for_sequence
    result: RewrittenInputIds = await si.rewrite_for_sequence(seq, needs_rewrite, assert_no_advance=assert_no_advance)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\interpreter.py", line 805, in rewrite_for_sequence
    state = await self.advance(state)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\interpreter.py", line 434, in advance
    await query_head.advance(None)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\multi_head_interpretation.py", line 89, in advance
    await self.handle_current_arg()
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\multi_head_interpretation.py", line 112, in handle_current_arg
    await self.advance(res)
  File "C:\Users\krist\AppData\Local\pypoetry\Cache\virtualenvs\REDACTED-cM1rhShM-py3.10\lib\site-packages\lmql\runtime\multi_head_interpretation.py", line 88, in advance
    self.current_args = await self.iterator_fct().asend(result)
  File "C:\Users\krist\AppData\Local\Temp\tmpd1cqlv38_compiled.py", line 101, in query
    yield ('result', await lmql.runtime_support.call(type_dict_to_type_instance, json_payload, ty))
UnboundLocalError: local variable 'json_payload' referenced before assignment
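
The combination of a printed "Failed to parse JSON" message followed by an UnboundLocalError is consistent with a common Python pattern: a variable assigned only inside a `try` block whose `except` branch logs and continues. The sketch below illustrates that general pattern with hypothetical function names; it is an assumption about the shape of the failure, not the actual LMQL source.

```python
import json


def parse_payload_buggy(text: str) -> dict:
    """Hypothetical illustration: the except branch swallows the error."""
    try:
        json_payload = json.loads(text)
    except json.JSONDecodeError:
        # The failure is reported, but json_payload is never bound.
        print(f"Failed to parse JSON from {text}")
    return json_payload  # raises UnboundLocalError when parsing failed


def parse_payload_safe(text: str) -> dict:
    """Hypothetical alternative: surface the parse failure explicitly."""
    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Failed to parse JSON from {text!r}") from exc
```

Re-raising with context keeps the original parse error visible instead of masking it with an unrelated exception further down.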

lbeurerkellner commented 11 months ago

I just pushed a fix to main that resolves it. Thanks for reporting it; @dataclass support is still in Preview, so this kind of feedback is very helpful.

KristianMischke commented 11 months ago

@lbeurerkellner Awesome, thanks for that! Do you have an official channel for feedback?

I ask because I think @dataclass support is super powerful, but it could use some enhancements.

I could imagine something like this, perhaps:

@dataclass
class MyClass:
    @lmql.field("""where classification in ["positive", "negative", "neutral"]""")
    classification: str
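
Note that `@lmql.field` above is a proposed API, and Python does not allow decorators on annotated attributes. As a rough sketch of how a similar per-field constraint can be declared in valid Python today, the standard `dataclasses.field(metadata=...)` mechanism can carry the allowed values. The `choices` key and the `validate` helper are hypothetical names for illustration; LMQL is not assumed to read this metadata.

```python
from dataclasses import dataclass, field, fields


@dataclass
class MyClass:
    # "choices" is a hypothetical metadata key, not something LMQL is known to read.
    classification: str = field(
        metadata={"choices": ("positive", "negative", "neutral")}
    )


def validate(instance) -> None:
    """Raise ValueError if any field value is outside its declared choices."""
    for f in fields(instance):
        choices = f.metadata.get("choices")
        value = getattr(instance, f.name)
        if choices is not None and value not in choices:
            raise ValueError(f"{f.name}={value!r} not in {choices}")
```

This only validates after the fact; constraining the model during decoding, as the proposal asks for, would need support inside the library.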

All of these things can already be done by writing a prompt formatted as JSON, for example, and then writing the Python code to convert the JSON object into a Python object. Using @dataclass is really convenient by comparison, because the result is already in the target format.
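
For comparison, the manual route described above might look roughly like the following sketch. `Card` is a trimmed-down, hypothetical stand-in for `CardData`, and `card_from_json` is an illustrative helper, not part of LMQL.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class Card:  # hypothetical, reduced version of CardData
    name: str
    rules: str


def card_from_json(text: str) -> Card:
    """Parse a JSON object and map its known keys onto the dataclass."""
    data = json.loads(text)
    known = {f.name for f in fields(Card)}
    # Unexpected keys are dropped; missing keys raise TypeError in the constructor.
    return Card(**{k: v for k, v in data.items() if k in known})


card = card_from_json('{"name": "Yktlash", "rules": "Trample", "extra": 1}')
```

Every schema change has to be mirrored in both the prompt and this conversion code, which is the duplication that typed query variables avoid.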