outlines-dev / outlines

Structured Text Generation
https://outlines-dev.github.io/outlines/
Apache License 2.0

KeyError in `BetterFSM::FSMInfo` when input FSM `alphabet` contains UTF-8 characters that ends with `\xb8\x80` #833

Closed: m0g1cian closed this issue 1 month ago

m0g1cian commented 2 months ago

Describe the issue as clearly as possible:

Update 2

I can confirm that something is wrong with Numba's typed Dict implementation. See the upstream issue: https://github.com/numba/numba/issues/9542

Update

After some testing, it is clear that this KeyError occurs for characters whose UTF-8 encoding ends with \xb8\x80 (e.g. "帀", "㸀", "渀", "縀").
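To make the pattern concrete, here is a minimal sketch that checks the UTF-8 bytes of the characters listed above, plus "一" from the original report, against two characters from the same regex that work. The helper name `ends_with_b8_80` is made up for illustration.

```python
# Characters from this report that trigger the KeyError.
SUSPECTS = ["一", "帀", "㸀", "渀", "縀"]
# Characters from the same regex that were reported to work.
CONTROLS = ["二", "三"]


def ends_with_b8_80(ch: str) -> bool:
    """Return True if the UTF-8 encoding of `ch` ends with the bytes B8 80."""
    return ch.encode("utf-8").endswith(b"\xb8\x80")


for ch in SUSPECTS + CONTROLS:
    # Print the code point and raw bytes rather than the glyph itself,
    # so the output does not depend on the console encoding.
    print(f"U+{ord(ch):04X} {ch.encode('utf-8')!r} {ends_with_b8_80(ch)}")
```

In a UTF-8 sequence of three or more bytes, a trailing B8 80 means the low 12 bits of the code point are 0xE00, so the low byte of every affected code point is zero.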


When outlines builds a BetterFSM from a reference FSM (e.g. one produced by interegular) and the reference FSM contains the Chinese character "一", the numba.typed.Dict behind BetterFSM::alphabet_symbol_map somehow converts this character into an empty string, causing a KeyError whenever __getitem__ is triggered.

Steps/code to reproduce the bug:

debug_keyerror.py

import interegular
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm

if __name__ == "__main__":
    regex_string = r"(1|2|3|one|two|three|一|二|三)"
    regex_pattern = interegular.parse_pattern(regex_string)
    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
    fsm_info: FSMInfo = regex_fsm.fsm_info
    print(fsm_info)

Some insight:

Printing each (k, v) in alphabet_symbol_mapping_items before create_fsm_info() (right after outlines.fsm.regex.py::96):

e 3
w 9
2 1
h 4
t 8
o 6
r 7
1 0
3 2
n 5
一 10
二 12
三 11

Printing each (k, v) in alphabet_symbol_mapping_items inside create_fsm_info() while building alphabet_symbol_map (right after outlines.fsm.regex.py::139):

e 3
w 9
2 1
h 4
t 8
o 6
r 7
1 0
3 2
n 5
 10
二 12
三 11
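The two dumps above differ in a single entry. A small sketch for spotting such a difference programmatically, pairing entries by their integer value; the helper name `changed_symbols` is made up for illustration, and the data is copied from the dumps above.

```python
# (symbol, value) pairs as printed before create_fsm_info().
BEFORE = [("e", 3), ("w", 9), ("2", 1), ("h", 4), ("t", 8), ("o", 6), ("r", 7),
          ("1", 0), ("3", 2), ("n", 5), ("一", 10), ("二", 12), ("三", 11)]
# The same pairs as printed inside create_fsm_info(); value 10 lost its symbol.
AFTER = [("e", 3), ("w", 9), ("2", 1), ("h", 4), ("t", 8), ("o", 6), ("r", 7),
         ("1", 0), ("3", 2), ("n", 5), ("", 10), ("二", 12), ("三", 11)]


def changed_symbols(before, after):
    """Return (old_symbol, new_symbol, value) for every value whose symbol differs."""
    after_by_value = {value: symbol for symbol, value in after}
    return [
        (symbol, after_by_value.get(value), value)
        for symbol, value in before
        if after_by_value.get(value) != symbol
    ]


diff = changed_symbols(BEFORE, AFTER)
for old, new, value in diff:
    # !a prints ASCII-safe escapes, independent of the console encoding.
    print(f"value {value}: {old!a} -> {new!a}")
```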

Expected result:

I was able to get the expected result after tweaking two places:

  1. outlines.fsm.regex.py::112: change nb_unichar_2_type = numba.types.UnicodeCharSeq(2) to nb_unichar_2_type = numba.types.unicode_type
  2. outlines.fsm.regex.py::89: change alphabet_symbol_mapping_items to a simple python list alphabet_symbol_mapping_items = list((k,v) for k, v in self.alphabet._symbol_mapping.items() if k != anything_else)
FSMInfo(
  initial=0,
  finals={1},
  transitions=DictType[UniTuple(int64 x 2),int64]<iv=None>({(0, 0): 1, (0, 1): 1, (0, 2): 1, (0, 6): 2, (0, 8): 3, (0, 10): 1, (0, 11): 1, (0, 12): 1, (2, 5): 7, (3, 4): 4, (3, 9): 5, (4, 7): 6, (5, 6): 1, (6, 3): 7, (7, 3): 1}),
  trans_key_to_states=DictType[int64,ListType[int64]]<iv=None>({0: [0], 1: [0], 2: [0], 6: [0, 5], 8: [0], 10: [0], 11: [0], 12: [0], 5: [2], 4: [3], 9: [3], 7: [4], 3: [6, 7]}),
  alphabet_anything_value=13,
  alphabet_symbol_mapping=DictType[unicode_type,int64]<iv=None>({2: 1, 1: 0, o: 6, 3: 2, r: 7, 一: 10, n: 5, w: 9, e: 3, h: 4, 三: 11, 二: 12, t: 8})
)
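The second tweak listed above can be sketched in plain Python. The sentinel class and the `symbol_mapping` dict below are hypothetical stand-ins for interegular's `anything_else` and for `self.alphabet._symbol_mapping`; only the list expression itself is taken from the tweak.

```python
class _AnythingElse:
    """Hypothetical stand-in for interegular's `anything_else` sentinel."""


anything_else = _AnythingElse()

# Hypothetical stand-in for `self.alphabet._symbol_mapping`.
symbol_mapping = {"1": 0, "一": 10, "三": 11, "二": 12, anything_else: 13}

# Tweak 2: collect the items in a plain Python list, skipping the sentinel.
alphabet_symbol_mapping_items = list(
    (k, v) for k, v in symbol_mapping.items() if k != anything_else
)

for symbol, value in alphabet_symbol_mapping_items:
    # !a prints ASCII-safe escapes, independent of the console encoding.
    print(f"{symbol!a} -> {value}")
```

Because the items in this sketch are ordinary Python `str` objects, no fixed-width character type is involved at this stage.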

Error message:

Traceback (most recent call last):
  File "...\debug_keyerror.py", line 9, in <module>
    print(fsm_info)
  File "...\lib\collections\__init__.py", line 441, in __repr__
    return self.__class__.__name__ + repr_fmt % self
  File "...\lib\site-packages\numba\typed\typeddict.py", line 217, in __repr__
    body = str(self)
  File "...\lib\site-packages\numba\typed\typeddict.py", line 212, in __str__
    for k, v in self.items():
  File "...\lib\_collections_abc.py", line 911, in __iter__
    yield (key, self._mapping[key])
  File "...\lib\site-packages\numba\typed\typeddict.py", line 180, in __getitem__
    return _getitem(self, key)
  File "...\lib\site-packages\numba\typed\dictobject.py", line 783, in impl
    raise KeyError()
KeyError

Outlines/Python version information:

Version information

```
> python -c "from outlines import _version; print(_version.version)"
0.0.40

> python -c "import sys; print('Python', sys.version)"
Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:34:57) [MSC v.1936 64 bit (AMD64)]

> pip freeze
annotated-types==0.6.0
anyio==4.3.0
attrs==23.2.0
boltons @ file:///home/conda/feedstock_root/build_artifacts/boltons_1677499911949/work
Brotli @ file:///D:/bld/brotli-split_1693583621767/work
build==1.2.1
CacheControl==0.14.0
cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
cffi @ file:///D:/bld/cffi_1671179506518/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1688813409104/work
cleo==2.1.0
cloudpickle==3.0.0
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
conda==23.3.1
conda-libmamba-solver @ file:///home/conda/feedstock_root/build_artifacts/conda-libmamba-solver_1680508672016/work/src
conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1691048088238/work
conda_package_streaming @ file:///home/conda/feedstock_root/build_artifacts/conda-package-streaming_1691009212940/work
crashtest==0.4.1
cryptography @ file:///D:/bld/cryptography-split_1691444290667/work
diskcache==5.6.3
distlib==0.3.8
distro==1.9.0
docopt==0.6.2
dulwich==0.21.7
exceptiongroup==1.2.0
fastjsonschema==2.19.1
filelock==3.13.1
fsspec==2024.2.0
h11==0.14.0
h5py @ file:///D:/bld/h5py_1702471423597/work
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.21.3
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
importlib_metadata==7.1.0
inquirerpy==0.3.4
installer==0.7.0
interegular==0.3.3
jaraco.classes==3.4.0
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1704966972576/work
joblib==1.4.0
jsonpatch @ file:///home/conda/feedstock_root/build_artifacts/jsonpatch_1632759296524/work
jsonpointer==2.0
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
keyring==24.3.1
lark==1.1.9
libmambapy @ file:///D:/bld/mamba-split_1680791188848/work/libmambapy
llvmlite==0.42.0
mamba @ file:///D:/bld/mamba-split_1680791188848/work/mamba
MarkupSafe @ file:///D:/bld/markupsafe_1706900062361/work
menuinst @ file:///D:/bld/menuinst_1666839998718/work
more-itertools==10.2.0
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
msgpack==1.0.8
mypy==1.9.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
numba==0.59.1
numpy @ file:///D:/bld/numpy_1707225570061/work/dist/numpy-1.26.4-cp310-cp310-win_amd64.whl#sha256=6761da75b1528684e6bf4dabdbdded9d1eb4d0e9b299482c7ce152cfb3155106
openai==1.21.2
outlines==0.0.40
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work
parse @ file:///home/conda/feedstock_root/build_artifacts/parse_1706516706584/work
pexpect==4.9.0
pfzy==0.3.4
pipreqs==0.4.13
pkginfo==1.10.0
platformdirs==4.2.0
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1693086607691/work
poetry==1.8.2
poetry-core==1.9.0
poetry-plugin-export==1.7.1
prompt-toolkit==3.0.43
ptyprocess==0.7.0
pycosat @ file:///D:/bld/pycosat_1666836675990/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==2.7.0
pydantic_core==2.18.1
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1685514481738/work
pyproject_hooks==1.0.0
PySocks @ file:///D:/bld/pysocks_1661604991356/work
pywin32-ctypes==0.2.2
PyYAML @ file:///D:/bld/pyyaml_1695373629531/work
rapidfuzz==3.8.1
referencing==0.34.0
regex @ file:///D:/bld/regex_1703393598862/work
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
requests-toolbelt==1.0.0
rpds-py==0.18.0
ruamel.yaml @ file:///D:/bld/ruamel.yaml_1686994025923/work
ruamel.yaml.clib @ file:///D:/bld/ruamel.yaml.clib_1670412994006/work
safetensors==0.4.3
scipy==1.13.0
sglang==0.1.14
shellingham==1.5.4
sniffio==1.3.1
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180539862/work
tokenizers==0.19.1
tomli==2.0.1
tomlkit==0.12.4
toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1657485559105/work
torch==2.2.2
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1691671248568/work
transformers==4.40.0
trove-classifiers==2024.4.10
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1708904622550/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1689789803562/work
virtualenv==20.25.3
wcwidth==0.2.13
win-inet-pton @ file:///D:/bld/win_inet_pton_1667051142467/work
yarg==0.1.9
zipp==3.18.1
zstandard==0.19.0
```

Context for the issue:

I'm not sure why only the Chinese character "一" breaks everything while other Chinese characters work fine, as far as I can tell.

lapp0 commented 2 months ago

@m0g1cian opened an upstream issue: https://github.com/numba/numba/issues/9542

Per the thread, this appears to be an upstream bug on the numba side: UnicodeCharSeq has trouble handling a leading null byte (\x00).
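As a rough illustration of the null-byte explanation, the sketch below encodes the three Chinese characters from the regex as little-endian UTF-32, which is how NumPy lays out its `<U` strings. Whether `UnicodeCharSeq` uses exactly this layout internally is an assumption here, not something confirmed in this thread; the helper name `utf32le_bytes` is made up for illustration.

```python
def utf32le_bytes(ch: str) -> bytes:
    """Encode a single character as 4 little-endian UTF-32 bytes."""
    return ch.encode("utf-32-le")


for ch in ["一", "二", "三"]:
    raw = utf32le_bytes(ch)
    # Print the code point and bytes rather than the glyph itself,
    # so the output does not depend on the console encoding.
    print(f"U+{ord(ch):04X}", raw.hex(" "), "leading null byte:", raw[0] == 0)
```

Under that assumption, only "一" (U+4E00) starts with \x00, which would fit the observation that "二" and "三" are unaffected.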

There are a few options here:

import numba
import numpy as np
from numba.cpython.charseq import unicode_charseq_get_code

@numba.njit
def function():
    s = np.empty(3, dtype="<U1")
    s[0] = "一"
    s[1] = "二"
    s[2] = "三"
    return [unicode_charseq_get_code(item, 0) for item in s]

result = function()
print(result)

Output: [19968, 20108, 32]

M0gician commented 2 months ago

I made a local patch that fixes this issue in outlines. It makes the numba typed Dict and List always use unicode_type rather than unicode_charseq.

I'll make a PR soon.