dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

Using context-free grammars to guide generation does not work #959

Closed: border-b closed this issue 2 months ago

border-b commented 5 months ago

Describe the issue as clearly as possible:

I was trying to run the provided example that uses a context-free grammar (CFG) to guide generation, but there seems to be an issue with CFGGuide: running the example unchanged produces an error.

Steps/code to reproduce the bug:

import outlines

arithmetic_grammar = """
    ?start: expression

    ?expression: term (("+" | "-") term)*

    ?term: factor (("*" | "/") factor)*

    ?factor: NUMBER
           | "-" factor
           | "(" expression ")"

    %import common.NUMBER
"""

model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
generator = outlines.generate.cfg(model, arithmetic_grammar)
sequence = generator("Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:")

print(sequence)
# (8-2)

Expected result:

(8-2)

Error message:

Traceback (most recent call last)
Cell In[1], line 19
     17 model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
     18 generator = outlines.generate.cfg(model, arithmetic_grammar)
---> 19 sequence = generator("Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:")
     21 print(sequence)
     22 # (8-2)

File /usr/local/lib/python3.10/site-packages/outlines/generate/api.py:207, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng)
    205 while True:
    206     try:
--> 207         last_state = next(states)
    208         if max_tokens or stop_sequences:
    209             token_ids = last_state.token_ids

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:92, in sequence_generator(model, sampler, fsms, token_ids, sequence_weights, attention_masks, fsm_states, rng)
     89 fsms = reorder_fsms(fsms, ancestors)
     90 fsm_states = reorder_fsm_states(fsm_states, ancestors)
---> 92 fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids)
     93 is_finished = is_generation_finished(fsms, fsm_states)
     95 if is_finished:

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:131, in get_next_fsm_states(fsms, fsm_states, next_token_ids)
    114 def get_next_fsm_states(
    115     fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
    116 ) -> List[int]:
    117     """
    118 
    119     Parameters
   (...)
    129 
    130     """
--> 131     return [
    132         fsm.get_next_state(fsm_state, int(token_id[0]))
    133         for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids)
    134     ]

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:132, in <listcomp>(.0)
    114 def get_next_fsm_states(
    115     fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
    116 ) -> List[int]:
    117     """
    118 
    119     Parameters
   (...)
    129 
    130     """
    131     return [
--> 132         fsm.get_next_state(fsm_state, int(token_id[0]))
    133         for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids)
    134     ]

File /usr/local/lib/python3.10/site-packages/outlines/fsm/guide.py:416, in CFGGuide.get_next_state(self, state, token_id)
    413     self.reset_state = False
    414     state = self.start_state
--> 416 return self.regex_fsm.get_next_state(state, token_id)

AttributeError: 'CFGGuide' object has no attribute 'regex_fsm'
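The traceback shows `reorder_fsms` running just before `get_next_fsm_states` (generator.py lines 89-92). One plausible mechanism, inferred from the traceback and not confirmed in this thread, is that reordering hands back fresh copies of the guides, and a fresh copy has never built its `regex_fsm`. The sketch below reproduces that failure pattern in plain Python; `LazyGuide` and `inner_fsm` are hypothetical names, not outlines code. It also shows that a bare annotation such as `self.inner_fsm: dict` declares a type without creating the attribute.

```python
class LazyGuide:
    """Hypothetical stand-in for a guide that builds its inner FSM lazily."""

    def __init__(self, grammar: str) -> None:
        self.grammar = grammar
        # A bare annotation declares a type but does NOT create the attribute.
        self.inner_fsm: dict

    def get_next_instruction(self, state: int) -> str:
        # The inner FSM only comes into existence here.
        self.inner_fsm = {"state": state}
        return "allowed tokens"

    def get_next_state(self, state: int, token_id: int) -> int:
        # Raises AttributeError if get_next_instruction never ran on *this* instance.
        return self.inner_fsm["state"] + 1

    def copy(self) -> "LazyGuide":
        # Rebuilds from constructor arguments only, dropping lazily built state.
        return LazyGuide(self.grammar)


guide = LazyGuide("?start: expression")
guide.get_next_instruction(0)
print(guide.get_next_state(0, 42))  # 1: inner_fsm exists on this object

copied = guide.copy()  # hypothetical analogue of guides being copied on reorder
try:
    copied.get_next_state(0, 42)
except AttributeError as err:
    print(err)  # 'LazyGuide' object has no attribute 'inner_fsm'
```

Per the thread, upgrading to 0.0.43 removed this particular error for the reporter.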

Outlines/Python version information:


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
 - Avoid using `tokenizers` before the fork if possible
 - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
0.0.41
Python 3.10.8 (main, Dec 6 2022, 14:24:03) [GCC 10.2.1 20210110]
accelerate==0.30.1
aiohttp==3.8.3
aiosignal==1.3.1
aiostream==0.4.4
annotated-types==0.6.0
anyio==3.7.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asgiref==3.5.2
asttokens==2.2.1
async-lru==2.0.4
async-timeout==4.0.2
attrs==23.1.0
Babel==2.15.0
backcall==0.2.0
beautifulsoup4==4.12.3
bitsandbytes==0.43.1
bleach==6.1.0
bytecode==0.14.2
cattrs==23.1.2
certifi==2023.5.7
cffi==1.16.0
charset-normalizer==2.1.1
click==8.1.3
cloudpickle==2.0.0
comm==0.2.2
commonmark==0.9.1
datasets==2.19.1
ddsketch==2.0.4
ddtrace==1.5.2
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
docstring_parser==0.16
envier==0.4.0
exceptiongroup==1.1.1
executing==1.2.0
fastapi==0.88.0
fastjsonschema==2.19.1
fastprogress==1.0.0
filelock==3.14.0
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2024.3.1
grpclib==0.4.3
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.0
hyperframe==6.0.1
idna==3.4
importlib-metadata==4.8.1
interegular==0.3.3
ipykernel==6.29.4
ipython==8.14.0
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.4
json5==0.9.25
jsonpointer==2.4
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.8
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
lark==1.1.9
llvmlite==0.42.0
MarkupSafe==2.1.5
matplotlib-inline==0.1.6
mistune==3.0.2
modal==0.62.139
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.16
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
notebook==7.1.3
notebook_shim==0.2.4
numba==0.59.1
numpy==1.25.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
outlines==0.0.41
overrides==7.7.0
packaging==23.1
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.3
peft==0.10.0
pexpect==4.8.0
pickleshare==0.7.5
pillow==10.2.0
platformdirs==4.2.1
prometheus_client==0.20.0
prompt-toolkit==3.0.38
protobuf==3.20.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycparser==2.22
pydantic==2.7.1
pydantic_core==2.18.2
Pygments==2.15.1
pyrsistent==0.19.3
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
python-multipart==0.0.6
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.0.3
qtconsole==5.5.2
QtPy==2.4.1
referencing==0.35.1
regex==2024.5.10
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==12.3.0
rpds-py==0.18.1
safetensors==0.4.3
Send2Trash==1.8.3
sentencepiece==0.2.0
shtab==1.7.1
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.2
starlette==0.22.0
sympy==1.12
tblib==1.7.0
tenacity==8.2.2
terminado==0.18.1
tinycss2==1.3.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
torch==2.2.2+cu121
torchaudio==2.2.2+cu121
torchvision==0.17.2+cu121
tornado==6.4
tqdm==4.66.4
traitlets==5.9.0
transformers==4.40.2
triton==2.2.0
trl==0.8.6
typeguard==4.0.0
typer==0.6.1
types-certifi==2021.10.8.3
types-python-dateutil==2.9.0.20240316
types-toml==0.10.4
typing_extensions==4.9.0
tyro==0.8.4
tzdata==2024.1
unsloth @ git+https://github.com/unslothai/unsloth.git@47ffd39abd02338e8a5f226d0f529347fb7e5f89
uri-template==1.3.0
urllib3==2.2.1
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.10
xformers==0.0.25.post1
xmltodict==0.13.0
xxhash==3.4.1
yarl==1.9.2
zipp==3.15.0

Context for the issue:

I was trying to test the output with a custom grammar, but the provided example fails to generate any output.

rlouf commented 5 months ago

Can you upgrade outlines to 0.0.43 and try again?
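Before retrying, it can help to confirm which outlines version the running interpreter actually sees; in a notebook, a package upgraded after it was imported usually needs a kernel restart. A minimal sketch follows: `importlib.metadata.version` is the standard-library call, while `release_tuple` and `installed_outlines_at_least` are hypothetical helper names. The comparison only handles plain numeric releases; `packaging.version.Version` is the usual choice for full PEP 440 ordering.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def release_tuple(v: str) -> Tuple[int, ...]:
    """'0.0.43' -> (0, 0, 43); anything after the leading numeric release is ignored."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"not a version string: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_outlines_at_least(minimum: str = "0.0.43") -> Optional[bool]:
    """True/False if outlines is installed, None if it is not installed at all."""
    try:
        installed = version("outlines")
    except PackageNotFoundError:
        return None
    return release_tuple(installed) >= release_tuple(minimum)


print(installed_outlines_at_least())
```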

border-b commented 5 months ago

@rlouf Upgrading to 0.0.43 fixes this error, but now I get a different one:

Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/lark/lexer.py:673, in ContextualLexer.lex(self, lexer_state, parser_state)
    672 last_token = lexer_state.last_token  # Save last_token. Calling root_lexer.next_token will change this to the wrong token
--> 673 token = self.root_lexer.next_token(lexer_state, parser_state)
    674 raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name)

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:598, in BasicLexer.next_token(self, lex_state, parser_state)
    597         allowed = {"<END-OF-FILE>"}
--> 598     raise UnexpectedCharacters(lex_state.text, line_ctr.char_pos, line_ctr.line, line_ctr.column,
    599                                allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token],
    600                                state=parser_state, terminals_by_name=self.terminals_by_name)
    602 value, type_ = res

UnexpectedCharacters: No terminal matches 'e' in the current parser context, at line 1 col 193

98*0.5000000000000022*2.2204460492503131e-
                                        ^
Expected one of: 
    * RPAR
    * STAR
    * NUMBER
    * SLASH
    * PLUS
    * MINUS
    * LPAR

Previous tokens: Token('NUMBER', '2.2204460492503131')

During handling of the above exception, another exception occurred:

UnexpectedCharacters                      Traceback (most recent call last)
Cell In[2], line 19
     17 model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
     18 generator = outlines.generate.cfg(model, arithmetic_grammar)
---> 19 sequence = generator("Alice had 4 apples and Bob ate 2. Write an expression for Alice's apples:")
     21 print(sequence)
     22 # (8-2)

File /usr/local/lib/python3.10/site-packages/outlines/generate/api.py:207, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng)
    205 while True:
    206     try:
--> 207         last_state = next(states)
    208         if max_tokens or stop_sequences:
    209             token_ids = last_state.token_ids

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:80, in sequence_generator(model, sampler, fsms, token_ids, sequence_weights, attention_masks, fsm_states, rng)
     75 except IndexError:  # Exceeding the context length
     76     raise ContextLengthExceededError(
     77         "The input length exceeds the context length of the model."
     78     )
---> 80 allowed_tokens = get_allowed_tokens(fsms, fsm_states)
     81 biased_logits = bias_logits(logits, allowed_tokens)
     82 next_token_ids, ancestors, sequence_weights = sampler(
     83     biased_logits, sequence_weights, rng
     84 )

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:155, in get_allowed_tokens(fsms, fsm_states)
    138 def get_allowed_tokens(
    139     fsms: List["Guide"], fsm_states: List[int]
    140 ) -> List[Optional[Iterable[int]]]:
    141     """Get the new instructions for each sequence from the finite-state machine.
    142 
    143     Parameters
   (...)
    153 
    154     """
--> 155     return [
    156         fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states)
    157     ]

File /usr/local/lib/python3.10/site-packages/outlines/generate/generator.py:156, in <listcomp>(.0)
    138 def get_allowed_tokens(
    139     fsms: List["Guide"], fsm_states: List[int]
    140 ) -> List[Optional[Iterable[int]]]:
    141     """Get the new instructions for each sequence from the finite-state machine.
    142 
    143     Parameters
   (...)
    153 
    154     """
    155     return [
--> 156         fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states)
    157     ]

File /usr/local/lib/python3.10/site-packages/outlines/fsm/guide.py:349, in CFGGuide.get_next_instruction(self, state)
    346         self.regex_fsm_last = proposer
    348 interactive = self.parser.parse_interactive(self.generation)
--> 349 interactive.exhaust_lexer()
    351 options = {self.terminal_regexps[x] for x in interactive.accepts()}
    352 # add %ignore terminals

File /usr/local/lib/python3.10/site-packages/lark/parsers/lalr_interactive_parser.py:52, in InteractiveParser.exhaust_lexer(self)
     47 def exhaust_lexer(self) -> List[Token]:
     48     """Try to feed the rest of the lexer state into the interactive parser.
     49 
     50     Note that this modifies the instance in place and does not feed an '$END' Token
     51     """
---> 52     return list(self.iter_parse())

File /usr/local/lib/python3.10/site-packages/lark/parsers/lalr_interactive_parser.py:43, in InteractiveParser.iter_parse(self)
     35 def iter_parse(self) -> Iterator[Token]:
     36     """Step through the different stages of the parse, by reading tokens from the lexer
     37     and feeding them to the parser, one per iteration.
     38 
   (...)
     41     When the parse is over, the resulting tree can be found in ``InteractiveParser.result``.
     42     """
---> 43     for token in self.lexer_thread.lex(self.parser_state):
     44         yield token
     45         self.result = self.feed_token(token)

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:676, in ContextualLexer.lex(self, lexer_state, parser_state)
    674     raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name)
    675 except UnexpectedCharacters:
--> 676     raise e

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:665, in ContextualLexer.lex(self, lexer_state, parser_state)
    663     while True:
    664         lexer = self.lexers[parser_state.position]
--> 665         yield lexer.next_token(lexer_state, parser_state)
    666 except EOFError:
    667     pass

File /usr/local/lib/python3.10/site-packages/lark/lexer.py:598, in BasicLexer.next_token(self, lex_state, parser_state)
    596     if not allowed:
    597         allowed = {"<END-OF-FILE>"}
--> 598     raise UnexpectedCharacters(lex_state.text, line_ctr.char_pos, line_ctr.line, line_ctr.column,
    599                                allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token],
    600                                state=parser_state, terminals_by_name=self.terminals_by_name)
    602 value, type_ = res
    604 ignored = type_ in self.ignore_types

UnexpectedCharacters: No terminal matches 'e' in the current parser context, at line 1 col 193

98*0.5000000000000022*2.2204460492503131e-
                                        ^
Expected one of: 
    * RPAR
    * STAR
    * SLASH
    * PLUS
    * MINUS

Previous tokens: Token('NUMBER', '2.2204460492503131')

So the generator runs now, but the output doesn't seem to follow the grammar?
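The failing text ends in `2.2204460492503131e-`, and the error shows the lexer accepting `2.2204460492503131` as a NUMBER and then rejecting `e`. Lark's `common.NUMBER` allows scientific notation, so this looks like an unfinished exponent (for example `e-16`) rather than text outside the grammar. A possible reading, an assumption not confirmed in this thread, is that the partial generation is lexed while its last token is still incomplete. The sketch below uses a hand-written regex that approximates `common.NUMBER` (not lark's own definition) to contrast a complete token with a viable prefix; `is_complete_number` and `can_still_become_number` are hypothetical helpers.

```python
import re

# Hand-written approximation of lark's common.NUMBER (FLOAT | INT); an assumption.
INT = r"[0-9]+"
EXP = r"[eE][+-]?[0-9]+"
DECIMAL = rf"(?:{INT}\.(?:{INT})?|\.{INT})"
FLOAT = rf"(?:{INT}{EXP}|{DECIMAL}(?:{EXP})?)"
NUMBER = re.compile(rf"(?:{FLOAT}|{INT})")


def is_complete_number(text: str) -> bool:
    """True if the whole text is a single NUMBER token."""
    return NUMBER.fullmatch(text) is not None


def can_still_become_number(text: str, extensions=("0", "1", "16")) -> bool:
    """Crude viability check: does appending a few digits complete the token?"""
    return any(is_complete_number(text + ext) for ext in extensions)


partial = "2.2204460492503131e-"
print(is_complete_number("2.2204460492503131"))  # True
print(is_complete_number(partial))  # False: the exponent has no digits yet
print(can_still_become_number(partial))  # True: e.g. "...e-16" is a full NUMBER
```

If this reading is right, the grammar itself is fine and the problem would lie in lexing an unfinished final token.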

lapp0 commented 5 months ago

This seems to be the same as the issue I ran into in https://github.com/outlines-dev/outlines/issues/796

I'll be working on getting CFG generation and the parser into a good state over the coming weeks. You can track progress by subscribing to this issue: https://github.com/outlines-dev/outlines/issues/684