KeyError in `BetterFSM::FSMInfo` when input FSM `alphabet` contains UTF-8 characters that end with `\xb8\x80` #833

m0g1cian commented 2 months ago

Describe the issue as clearly as possible:

Update 2

Can confirm there's something wrong with Numba's Typed Dict implementation. Check issue here


After some testing, it is clear that this KeyError occurs for characters whose UTF-8 encoding ends with \xb8\x80 (e.g. "帀", "㸀", "渀", "縀").
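A quick standard-library check of this pattern. The code points below are my assumption for the characters listed above (U+3E00, U+5E00, U+6E00, U+7E00), plus "一" at U+4E00; the sketch only inspects encodings and does not touch numba:

```python
# Code points assumed for the characters in the report, plus "一" (U+4E00).
code_points = [0x3E00, 0x4E00, 0x5E00, 0x6E00, 0x7E00]


def utf8_suffix(cp: int) -> bytes:
    """Return the last two bytes of the UTF-8 encoding of a code point."""
    return chr(cp).encode("utf-8")[-2:]


for cp in code_points:
    ch = chr(cp)
    # Show the full UTF-8 encoding and the low byte of the code point.
    print(f"U+{cp:04X} {ch!r} utf8={ch.encode('utf-8')!r} low_byte={cp & 0xFF:#04x}")

# Collect the distinct suffixes and low bytes across all characters.
suffixes = {utf8_suffix(cp) for cp in code_points}
low_bytes = {cp & 0xFF for cp in code_points}
```

All of these share the UTF-8 suffix \xb8\x80, and each code point has a zero low byte.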

When outlines builds BetterFSM from a reference FSM (e.g. from interegular) that contains the Chinese character "一", the numba.typed.Dict used by BetterFSM::alphabet_symbol_map somehow converts this character into an empty string, causing a KeyError whenever __getitem__ is triggered.

Steps/code to reproduce the bug:


import interegular
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm

if __name__ == "__main__":
    regex_string = r"(1|2|3|one|two|three|一|二|三)"
    regex_pattern = interegular.parse_pattern(regex_string)
    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
    fsm_info: FSMInfo = regex_fsm.fsm_info

Some insight:

Printing each (k, v) in alphabet_symbol_mapping_items before create_fsm_info() (right after outlines.fsm.regex.py::96):

e 3
w 9
2 1
h 4
t 8
o 6
r 7
1 0
3 2
n 5
一 10
二 12
三 11

Printing each (k, v) in alphabet_symbol_mapping_items inside create_fsm_info() while building alphabet_symbol_map (right after outlines.fsm.regex.py::139):

e 3
w 9
2 1
h 4
t 8
o 6
r 7
1 0
3 2
n 5
二 12
三 11

Expected result:

I was able to get the expected result after tweaking two places:

  1. outlines.fsm.regex.py::112: change nb_unichar_2_type = numba.types.UnicodeCharSeq(2) to nb_unichar_2_type = numba.types.unicode_type
  2. outlines.fsm.regex.py::89: change alphabet_symbol_mapping_items to a plain Python list: alphabet_symbol_mapping_items = list((k,v) for k, v in self.alphabet._symbol_mapping.items() if k != anything_else)

Output after both changes (excerpt):

  transitions=DictType[UniTuple(int64 x 2),int64]<iv=None>({(0, 0): 1, (0, 1): 1, (0, 2): 1, (0, 6): 2, (0, 8): 3, (0, 10): 1, (0, 11): 1, (0, 12): 1, (2, 5): 7, (3, 4): 4, (3, 9): 5, (4, 7): 6, (5, 6): 1, (6, 3): 7, (7, 3): 1}),
  trans_key_to_states=DictType[int64,ListType[int64]]<iv=None>({0: [0], 1: [0], 2: [0], 6: [0, 5], 8: [0], 10: [0], 11: [0], 12: [0], 5: [2], 4: [3], 9: [3], 7: [4], 3: [6, 7]}),
  alphabet_symbol_mapping=DictType[unicode_type,int64]<iv=None>({2: 1, 1: 0, o: 6, 3: 2, r: 7, 一: 10, n: 5, w: 9, e: 3, h: 4, 三: 11, 二: 12, t: 8})

Error message:

Traceback (most recent call last):
  File "...\debug_keyerror.py", line 9, in <module>
  File "...\lib\collections\__init__.py", line 441, in __repr__
    return self.__class__.__name__ + repr_fmt % self
  File "...\lib\site-packages\numba\typed\typeddict.py", line 217, in __repr__
    body = str(self)
  File "...\lib\site-packages\numba\typed\typeddict.py", line 212, in __str__
    for k, v in self.items():
  File "...\lib\_collections_abc.py", line 911, in __iter__
    yield (key, self._mapping[key])
  File "...\lib\site-packages\numba\typed\typeddict.py", line 180, in __getitem__
    return _getitem(self, key)
  File "...\lib\site-packages\numba\typed\dictobject.py", line 783, in impl
    raise KeyError()

Outlines/Python version information:

Version information

Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:34:57) [MSC v.1936 64 bit (AMD64)]
outlines==0.0.40
numba==0.59.1
numpy==1.26.4

Context for the issue:

I'm not sure why only the Chinese character "一" breaks everything while other Chinese characters work fine, as far as I can tell.

lapp0 commented 2 months ago

@m0g1cian opened an upstream issue: https://github.com/numba/numba/issues/9542

Per the thread, this appears to be an upstream bug on the numba side: UnicodeCharSeq has trouble handling a leading null byte (\x00).
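To make the "leading null byte" concrete, here is a standard-library sketch. It assumes the characters are stored as 4-byte little-endian code units (as in NumPy's `<U1` dtype); under that assumption "一" (U+4E00) starts with \x00, while "二" and "三" do not:

```python
def utf32le_bytes(ch: str) -> bytes:
    """Encode a single character as 4 little-endian bytes (no BOM)."""
    return ch.encode("utf-32-le")


# Byte layout of the three Chinese characters from the reproduction regex.
layout = {ch: utf32le_bytes(ch) for ch in "一二三"}
for ch, raw in layout.items():
    print(ch, raw.hex(" "), "<- leading null byte" if raw[0] == 0 else "")

# Characters whose first stored byte is \x00.
starts_with_null = [ch for ch, raw in layout.items() if raw[0] == 0]
```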

There are a few options here:

import numba
import numpy as np
from numba.cpython.charseq import unicode_charseq_get_code

def function():
    s = np.empty(3, dtype="<U1")
    s[0] = "一"
    s[1] = "二"
    s[2] = "三"
    return [unicode_charseq_get_code(item, 0) for item in s]

result = function()

Output: [19968, 20108, 32]

M0gician commented 2 months ago

I made a local patch to fix this issue in outlines. It makes the numba typed Dict and List always use unicode_type rather than unicode_charseq.

I'll make a PR soon.