dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

Very slow crawl in interegular, scalability issue #680

Open viktor-ferenczi opened 8 months ago

viktor-ferenczi commented 8 months ago

Describe the issue as clearly as possible:

Observed that the vLLM server gets stuck indefinitely. Running in the debugger I could stop where it was frozen. It is running this while loop infinitely, because it always appends a new item to states, therefore it can never finish the loop (the list is just getting longer all the time):

def crawl(alphabet, initial, final, follow):
    """
        Given the above conditions and instructions, crawl a new unknown FSM,
        mapping its states, final states and transitions. Return the new FSM.
        This is a pretty powerful procedure which could potentially go on
        forever if you supply an evil version of follow().
    """

    states = [initial]
    finals = set()
    map = {}

    # iterate over a growing list
    i = 0
    while i < len(states):
        state = states[i]

        # add to finals
        if final(state):
            finals.add(i)

        # compute map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state. Don't list it.
                continue
            else:
                try:
                    j = states.index(next)
                except ValueError:
                    j = len(states)
                    states.append(next)
                map[i][transition] = j

        i += 1

    return FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

Steps/code to reproduce the bug:

The regex used via vLLM's API:

(Path: `Shop\.Service/OrderService\.cs`\n\n\n(\n|[^`].*?\n)*\n\n)?(Path: `Shop\.Web/Pages/Order/Archive\.cshtml\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Web/Controllers/OrderController\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Web/Controllers/AccountController\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Data/Enums/OrderBy\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Data/IOrder\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(New: `(.*?)`\n\n```([a-z]+)\n(\n|[^`].*?\n)*```\n\n)?(New: `(.*?)`\n\n```([a-z]+)\n(\n|[^`].*?\n)*```\n\n)?(New: `(.*?)`\n\n```([a-z]+)\n(\n|[^`].*?\n)*```\n\n)?END\n

Variable values in the infinite loop state, at the line states.append(next):

i = 22100
j = 22115
transition = 11

next = frozenset({(1, 55), (2, 65), (3, 68), (5, 66), (6, 38), (6, 53), (7, 27), (8, 27), (9, 27)})

alphabet.by_transition = {0: ['r'], 1: ['n'], 2: ['B'], 3: ['a'], 4: ['u'], 5: ['c'], 6: ['l'], 7: ['f', 'x', 'q', 'j', 'k', 'z'], 8: ['y'], 9: ['h'], 10: [' '], 11: ['`'], 12: ['i'], 13: ['t'], 14: ['O'], 15: ['b'], 16: ['g'], 17: ['A'], 18: ['S'], 19: ['v'], 20: [':'], 21: ['D'], 22: ['o'], 23: ['m'], 24: ['W'], 25: ['d'], 26: ['P'], 27: ['/'], 28: ['C'], 29: ['.'], 30: ['\n'], 31: ['I'], 32: [anything_else], 33: ['E'], 34: ['e'], 35: ['N'], 36: ['s'], 37: ['w'], 38: ['p']}

states = [..., frozenset({(4, 39), (7, 20), (3, 46), (5, 59), (6, 46), (8, 20), (2, 58), (3, 61), (9, 20)}), frozenset({(4, 39), (5, 59), (6, 46), (2, 58), (3, 61), (6, 17), (4, 17), (1, 48)}), frozenset({(4, 39), (2, 58), (2, 17), (5, 59), (6, 46), (1, 48), (3, 61)}), frozenset({(4, 39), (3, 17), (5, 59), (6, 46), (5, 17), (2, 58), (3, 61), (1, 48)}), frozenset({(4, 45), (7, 26), (8, 26), (2, 64), (5, 65), (6, 52), (3, 67), (1, 54), (9, 26)}), frozenset({(7, 27), (5, 66), (8, 27), (3, 68), (2, 65), (6, 53), (1, 55), (9, 27), (4, 46)}), frozenset({(5, 60), (6, 47), (3, 62), (1, 49), (9, 21), (4, 40), (3, 57), (7, 21), (8, 21)}), frozenset({(4, 39), (7, 20), (5, 59), (8, 20), (2, 45), (2, 58), (3, 61), (1, 48), (9, 20)}), frozenset({(4, 39), (7, 20), (5, 59), (8, 20), (3, 61), (2, 58), (5, 45), (1, 48), (9, 20)}), frozenset({(4, 39), (7, 20), (5, 59), (8, 20), (2, 58), (3, 61), (3, 45), (1, 48), (9, 20)})]

Expected result:

vLLM should produce matching content. The prompt instructs the model to do so, and it worked before with a less strict regex that did not contain the actual file names.

Error message:

None, vLLM freezes with 100% core load. It may not be completely frozen, just very-very slow. The GPU load according to `nvidia-smi` is zero, therefore vLLM cannot make any progress.

Outlines/Python version information:

Version information

```
outlines local dev install from latest `main` branch `e99d92d0`
packaging==23.2
paginate==0.5.6
pandas==2.2.0
pathspec==0.12.1
perscache==0.6.1
pillow==10.2.0
platformdirs==4.2.0
pluggy==1.4.0
protobuf==4.25.2
psutil==5.9.8
pyarrow==15.0.0
pyarrow-hotfix==0.6
pycparser==2.21
pydantic==2.6.0
pydantic_core==2.16.1
Pygments==2.17.2
pymdown-extensions==10.7
pynvml==11.5.0
pytest==8.0.0
python-dateutil==2.8.2
python-dotenv==1.0.1
pytz==2024.1
PyYAML==6.0.1
pyyaml_env_tag==0.1
quantile-python==1.1
ray==2.9.1
referencing==0.33.0
regex==2023.12.25
requests==2.31.0
rpds-py==0.17.1
safetensors==0.4.2
scipy==1.12.0
sentencepiece==0.1.99
six==1.16.0
sniffio==1.3.0
starlette==0.35.1
sympy==1.12
tinycss2==1.2.1
tokenizers==0.15.1
tomli==2.0.1
torch==2.1.2
tqdm==4.66.1
transformers==4.37.2
triton==2.1.0
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.2.0
uvicorn==0.27.0.post1
uvloop==0.19.0
vllm==0.3.1
watchdog==4.0.0
watchfiles==0.21.0
webencodings==0.5.1
websockets==12.0
xformers==0.0.23.post1
xxhash==3.4.1
yarl==1.9.4
```

Context for the issue:

I'm just running some heavy code understanding and generation workload through outlines.

viktor-ferenczi commented 8 months ago

@lapp0 Frozen FSM, maybe low-hanging fruit. But I don't know the outlines code well enough to see the problem right away. The regex seems to be good, but a bit long. Still, it should not cause a complete freeze.

lapp0 commented 8 months ago

Making long expressions efficient is a work in progress. interegular isn't quite there for complex expressions.

Related: https://github.com/outlines-dev/outlines/issues/658

Your expression results in a very large FSM. Before we even start crawling the FSM to create a token index, interegular.parse_pattern().to_fsm() takes over 10 minutes to compile, and FSM.reduce() takes over 30 minutes. I don't think it's getting into an infinite loop; this is just a very expensive operation.

Detailed profile for the below code follows

import interegular

def profile_fsm_construction(pattern_str):
    pat = interegular.parse_pattern(pattern_str)
    fsm = pat.to_fsm()
    reduced = fsm.reduce()

pattern = """(Path: `Shop\.Service/OrderService\.cs`\n\n\n(\n|[^`].*?\n)*\n\n)?(Path: `Shop\.Web/Pages/Order/Archive\.cshtml\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Web/Controllers/OrderController\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Web/Controllers/AccountController\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Data/Enums/OrderBy\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(Path: `Shop\.Data/IOrder\.cs`\n\n```cs\n(\n|[^`].*?\n)*```\n\n)?(New: `(.*?)`\n\n```([a-z]+)\n(\n|[^`].*?\n)*```\n\n)?(New: `(.*?)`\n\n```([a-z]+)\n(\n|[^`].*?\n)*```\n\n)?(New: `(.*?)`\n\n```([a-z]+)\n(\n|[^`].*?\n)*```\n\n)?END\n"""

if __name__ == "__main__":
    import pstats
    import cProfile

    cProfile.run('profile_fsm_construction(pattern)', 'profile_stats')
    p = pstats.Stats('profile_stats')
    p.sort_stats('cumtime').print_stats()
```
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 2637.640 2637.640 {built-in method builtins.exec}
1 0.018 0.018 2637.640 2637.640 :1()
1 0.000 0.000 2637.622 2637.622 /home/andrew/p/outlines/profile_fsm.py:4(profile_fsm_construction)
262 41.940 0.160 2636.706 10.064 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:969(crawl)
1 0.012 0.012 1837.565 1837.565 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:249(reduce)
2 0.706 0.353 1837.553 918.776 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:558(reversed)
170625 1043.104 0.006 1757.005 0.010 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:580(follow)
1948533 827.305 0.000 827.305 0.000 {method 'index' of 'list' objects}
25/1 0.010 0.000 800.049 800.049 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:447(to_fsm)
25 0.000 0.000 641.479 25.659 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:451(union)
25 0.003 0.000 641.479 25.659 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:913(parallel)
2635780375 402.243 0.000 402.243 0.000 {method 'get' of 'dict' objects}
2640067274 312.260 0.000 312.260 0.000 {method 'update' of 'set' objects}
59/2 0.000 0.000 158.560 79.280 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:453()
34/1 0.001 0.000 158.560 158.560 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:370(to_fsm)
67 0.006 0.000 158.487 2.365 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:310(concatenate)
946959 5.441 0.000 7.330 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:347(follow)
924417 2.801 0.000 2.801 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:925(follow)
4286906 1.285 0.000 1.287 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:319(connect_all)
23703 0.078 0.000 0.158 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:939(final)
33/9 0.000 0.000 0.156 0.017 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:280(to_fsm)
1019344 0.113 0.000 0.113 0.000 {method 'add' of 'set' objects}
23703 0.064 0.000 0.064 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:940()
136046 0.060 0.000 0.060 0.000 {built-in method builtins.len}
151 0.007 0.000 0.042 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:112(union)
33 0.000 0.000 0.038 0.001 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:364(__add__)
24281 0.033 0.000 0.033 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:340(final)
151 0.005 0.000 0.030 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:114()
110 0.000 0.000 0.028 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:409(times)
42 0.000 0.000 0.026 0.001 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:445(__mul__)
501/1 0.001 0.000 0.025 0.025 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:69(get_alphabet)
25/1 0.000 0.000 0.025 0.025 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:423(_get_alphabet)
59/2 0.000 0.000 0.025 0.012 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:425()
34/1 0.000 0.000 0.025 0.025 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:330(_get_alphabet)
42422 0.008 0.000 0.025 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:114()
53097 0.024 0.000 0.024 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:71(by_transition)
467/14 0.000 0.000 0.024 0.002 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:331()
33/9 0.000 0.000 0.024 0.003 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:270(_get_alphabet)
40412 0.009 0.000 0.017 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:100(__getitem__)
23864 0.016 0.000 0.016 0.000 {built-in method builtins.any}
76080 0.015 0.000 0.015 0.000 {method 'append' of 'list' objects}
22503 0.013 0.000 0.015 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:429(follow)
4375 0.014 0.000 0.014 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:589(final)
58 0.000 0.000 0.010 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:374(star)
23333 0.006 0.000 0.009 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:38(__hash__)
1 0.000 0.000 0.009 0.009 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:728(parse_pattern)
1886/1 0.001 0.000 0.009 0.009 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:34(w)
1 0.000 0.000 0.009 0.009 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:484(parse)
1 0.000 0.000 0.009 0.009 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:57(parse)
1 0.000 0.000 0.009 0.009 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:490(start)
25/1 0.000 0.000 0.009 0.009 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:497(pattern)
34/1 0.000 0.000 0.008 0.008 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:503(conc)
467/14 0.000 0.000 0.008 0.001 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:512(obj)
24/9 0.000 0.000 0.008 0.001 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:517(group)
431 0.002 0.000 0.006 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:129(to_fsm)
443 0.001 0.000 0.005 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:585(atom)
6279 0.005 0.000 0.005 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:384(follow)
6341 0.004 0.000 0.004 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:86(static_b)
714 0.004 0.000 0.004 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:188(__init__)
397 0.001 0.000 0.003 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:116(_get_alphabet)
842 0.001 0.000 0.003 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:600(repetition)
558 0.002 0.000 0.003 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:93(__init__)
1753 0.001 0.000 0.002 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1515(__and__)
398 0.000 0.000 0.002 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:128(from_groups)
27275 0.002 0.000 0.002 0.000 {method 'items' of 'dict' objects}
104/103 0.000 0.000 0.002 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:80(prefix_postfix)
24825 0.002 0.000 0.002 0.000 {built-in method builtins.hash}
23333 0.001 0.000 0.001 0.000 {built-in method builtins.id}
34 0.000 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:333(_get_prefix_postfix)
1806 0.001 0.000 0.001 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:686(__call__)
151 0.001 0.000 0.001 0.000 {method 'union' of 'frozenset' objects}
151 0.001 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:119()
1352/866 0.000 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:91(lengths)
1 0.000 0.000 0.001 0.001 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:437(_get_prefix_postfix)
33/9 0.000 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:276(_get_lengths)
24/9 0.000 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:427(_get_lengths)
33/9 0.000 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:361(_get_lengths)
43 0.000 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:151()
398 0.000 0.000 0.001 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:130()
1806 0.001 0.000 0.001 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1093(__new__)
151 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:118()
1960 0.000 0.000 0.000 0.000 {built-in method builtins.hasattr}
1492 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1230(__hash__)
13 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:635(escaped)
446 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:118(any_but)
1219 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:113()
3996 0.000 0.000 0.000 0.000 {built-in method builtins.isinstance}
50 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:50(_combine_flags)
12 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:184(to_fsm)
25/1 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:458(simplify)
59/2 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:466()
34/1 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:410(simplify)
467/14 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:411()
33/9 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:298(simplify)
388 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:157()
161 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:402(final)
12 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:686(chargroup)
1272 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1502(__bool__)
912 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1535(__invert__)
26 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:95(anyof)
577 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:421(final)
9 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:827(copy)
12 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:196()
24 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:706(chargroup_inner)
9 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:150(copy)
1511 0.000 0.000 0.000 0.000 {method 'keys' of 'dict' objects}
50 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1505(__or__)
151 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:122()
331 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:403()
26 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:131(multiple)
67 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:316()
13 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:102(anyof_b)
3 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1375(_missing_)
443 0.000 0.000 0.000 0.000 {method 'issubset' of 'set' objects}
50 0.000 0.000 0.000 0.000 :2(__init__)
55 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:90(__iter__)
397 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:126(_get_lengths)
36 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:78(static)
81 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:724()
397 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:168(simplify)
10 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1355(_iter_member_by_value_)
25 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:921()
6 0.000 0.000 0.000 0.000 :117(__instancecheck__)
25 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/fsm.py:919()
98 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:16(__init__)
18 0.000 0.000 0.000 0.000 {method 'copy' of 'dict' objects}
26 0.000 0.000 0.000 0.000 {method 'extend' of 'list' objects}
6 0.000 0.000 0.000 0.000 {built-in method _abc._abc_instancecheck}
1 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:199(_get_alphabet)
55 0.000 0.000 0.000 0.000 {built-in method builtins.iter}
10 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:117(_iter_bits_lsb)
1 0.000 0.000 0.000 0.000 :121(__subclasscheck__)
1 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:480(__init__)
78 0.000 0.000 0.000 0.000 {built-in method builtins.chr}
1 0.000 0.000 0.000 0.000 {built-in method _abc._abc_subclasscheck}
3 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:70(peek_static)
1 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/utils/simple_parser.py:47(__init__)
18 0.000 0.000 0.000 0.000 {method 'copy' of 'frozenset' objects}
13 0.000 0.000 0.000 0.000 {method 'isalpha' of 'str' objects}
3 0.000 0.000 0.000 0.000 /nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/enum.py:1451()
3 0.000 0.000 0.000 0.000 {built-in method __new__ of type object at 0x7f25367b3bc0}
3 0.000 0.000 0.000 0.000 {method 'setdefault' of 'dict' objects}
12 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:211(simplify)
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
3 0.000 0.000 0.000 0.000 {method 'join' of 'str' objects}
6 0.000 0.000 0.000 0.000 {built-in method builtins.ord}
1 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:123(_get_prefix_postfix)
1 0.000 0.000 0.000 0.000 /home/andrew/p/outlines/.myenv/lib/python3.11/site-packages/interegular/patterns.py:208(_get_lengths)
```

For now I recommend trying to simplify the expression. However, in the long term - this isn't the first time this has come up, and should be addressed. I'll ponder how the crawl function (and the passed follows) could be made more efficient.

viktor-ferenczi commented 8 months ago

This is a blocker for me, so gave the optimization a try.

Thank you for the test case, that was a highly useful start.

In fsm.py of interegular:

def final(state):
    """If you're in a final state of the final FSM, it's final"""
    for (i, substate) in state:
        if i == last_index and substate in last.finals:
            return True
    return False

Since both state and last.finals are frozenset it can be optimized as:

def final(state):
    """If you're in a final state of the final FSM, it's final"""
    if len(state) < len(last.finals):
        for (i, substate) in state:
            if i == last_index and substate in last.finals:
                return True
    else:
        for final_substate in last.finals:
            if (last_index, final_substate) in state:
                return True 
    return False

But this barely makes a difference (<2%). That's because this code scales much worse:

In the crawl function there is this line:

j = states.index(next)

The number of items in states goes well above 1000 even if I simplify my prompt down to 5 concatenated segments. Each states.index(next) call is a linear scan, so the crawl as a whole has O(N^2) time complexity in the number of states, which is clearly not optimal for these kinds of patterns. Also, the code is using exception handling (ValueError) for logic, so it constructs a useless Traceback whenever an index lookup fails. 🤦‍♂️ Also, states is not used at the end, only its length.

Attempted to optimize the crawl function by mapping state values to their indices, but the states mix up data types, specifically set, frozenset and dict.
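
One way around both the linear scan and the mixed state types is to normalize every state to a hashable key and keep a dict from key to index. The following is only a sketch of that idea, not interegular's code; `canonical_key` and `StateIndex` are hypothetical names, and it assumes states are sets, frozensets, dicts with hashable keys, or values that are already hashable.

```python
from typing import Any, Dict, Hashable, List


def canonical_key(state: Any) -> Hashable:
    """Return a hashable stand-in for a state (hypothetical helper).

    A set and a frozenset with the same items get the same key, matching
    Python's own equality. Dicts are tagged separately so that a dict never
    compares equal to a set of pairs.
    """
    if isinstance(state, (set, frozenset)):
        return ("set", frozenset(state))
    if isinstance(state, dict):
        return ("dict", frozenset((k, canonical_key(v)) for k, v in state.items()))
    return state  # assumed to be hashable already


class StateIndex:
    """Assign consecutive indices to states with O(1) average lookup."""

    def __init__(self) -> None:
        self.states: List[Any] = []
        self._index: Dict[Hashable, int] = {}

    def add(self, state: Any) -> int:
        """Return the index of `state`, appending it first if it is new."""
        key = canonical_key(state)
        j = self._index.get(key)
        if j is None:
            j = len(self.states)
            self.states.append(state)
            self._index[key] = j
        return j


index = StateIndex()
print(index.add({(1, 55), (2, 65)}))             # 0
print(index.add(frozenset({(2, 65), (1, 55)})))  # 0, same state as above
```

Because the dict is keyed on the normalized value rather than on its `hash()`, two distinct states that happen to share a hash value are still kept apart.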

At this point I gave up.

I could use a Lark grammar instead for my purposes, because interegular is badly written, but that's not available via the vLLM REST API yet (looked into outlines.serve.serve.generate).

I'm blocked.

viktor-ferenczi commented 8 months ago

There is also this expensive double-reversal in interegular.fsm:

    def reduce(self):
        """
            A result by Brzozowski (1963) shows that a minimal finite state machine
            equivalent to the original can be obtained by reversing the original
            twice.
        """
        return self.reversed().reversed()

Do we really need a minimal finite state machine that badly for our purposes?

However, it remains super slow even if I remove the above, so it would not help anyway.
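
To see where the cost of the double reversal comes from, here is a toy version of Brzozowski's construction on a plain dict-based automaton. It is an illustrative sketch and not interegular's implementation: the automaton is assumed to be given as sets of initial and final states plus a `trans[(state, symbol)] -> set of states` mapping, and all function names are made up for this example. Each pass reverses the automaton and then runs a subset construction, and it is the subset construction whose state count can blow up.

```python
from collections import defaultdict


def reverse(initials, finals, trans):
    """Flip every transition and swap initial and final states."""
    rev = defaultdict(set)
    for (src, sym), dsts in trans.items():
        for dst in dsts:
            rev[(dst, sym)].add(src)
    return set(finals), set(initials), dict(rev)


def determinize(initials, finals, trans, alphabet):
    """Subset construction; the dead (empty) subset is left out."""
    start = frozenset(initials)
    index = {start: 0}
    subsets = [start]
    new_trans, new_finals = {}, set()
    i = 0
    while i < len(subsets):
        subset = subsets[i]
        if subset & finals:
            new_finals.add(i)
        for sym in alphabet:
            target = frozenset(d for s in subset for d in trans.get((s, sym), ()))
            if not target:
                continue
            j = index.get(target)
            if j is None:
                j = len(subsets)
                index[target] = j
                subsets.append(target)
            new_trans[(i, sym)] = {j}
        i += 1
    return {0}, new_finals, new_trans, len(subsets)


def brzozowski(initials, finals, trans, alphabet):
    """Reverse and determinize twice; returns the automaton and its size."""
    size = 0
    for _ in range(2):
        initials, finals, trans = reverse(initials, finals, trans)
        initials, finals, trans, size = determinize(initials, finals, trans, alphabet)
    return initials, finals, trans, size


def accepts(initials, finals, trans, word):
    """Run `word` through the automaton and report acceptance."""
    current = set(initials)
    for sym in word:
        current = {d for s in current for d in trans.get((s, sym), ())}
    return bool(current & finals)


# A redundant 3-state DFA for "strings over {a, b} ending in a".
trans = {
    (0, "a"): {1}, (0, "b"): {0},
    (1, "a"): {2}, (1, "b"): {0},
    (2, "a"): {1}, (2, "b"): {0},
}
initials, finals, minimal, size = brzozowski({0}, {1, 2}, trans, "ab")
print(size)  # 2
```

On this small input the two passes are cheap, but the intermediate automaton after the first pass is the determinized reverse of the original, and nothing bounds its size by the size of the final result.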

lapp0 commented 8 months ago

I made crawl about 20x faster with

def crawl(alphabet, initial, final, follow):
    """
        Given the above conditions and instructions, crawl a new unknown FSM,
        mapping its states, final states and transitions. Return the new FSM.
        This is a pretty powerful procedure which could potentially go on
        forever if you supply an evil version of follow().
    """

    def get_hash(obj):
        if isinstance(obj, set):
            return hash(frozenset(obj))
        elif isinstance(obj, dict):
            return hash(tuple(sorted(obj.items())))
        return hash(obj)

    states = [initial]
    state_idx = {get_hash(initial): 0}
    finals = set()
    map = {}

    # iterate over a growing list
    i = 0
    while i < len(states):
        state = states[i]

        # add to finals
        if final(state):
            finals.add(i)

        # compute map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
                next_hash = get_hash(next)
            except OblivionError:
                # Reached an oblivion state. Don't list it.
                continue
            else:
                try:
                    j = state_idx[next_hash]
                except KeyError:
                    j = len(states)
                    states.append(next)
                    if next_hash not in state_idx:
                        state_idx[next_hash] = j
                map[i][transition] = j

        i += 1

    return FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

The bottleneck is the reduce() function you referenced. If you implement Hopcroft's algorithm instead there might be two orders of magnitude of improvement.


https://www.cs.ru.nl/bachelors-theses/2017/Erin_van_der_Veen___4431200___The_Practical_Performance_of_Automata_Minimization_Algorithms.pdf
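
For reference, a compact sketch of Hopcroft-style partition refinement on a dict-based DFA could look like the following. It is independent of interegular's FSM class and makes two assumptions that are not from this thread: the DFA is complete (`delta` is defined for every state and symbol), and `hopcroft_partition` is a hypothetical name. The list-based worklist keeps the code short but does not achieve the algorithm's textbook running time.

```python
from collections import defaultdict


def hopcroft_partition(states, alphabet, delta, finals):
    """Group the states of a complete DFA into equivalence classes.

    `delta[(state, symbol)]` gives the single successor state. Each
    returned frozenset becomes one state of the minimal DFA.
    """
    states = set(states)
    finals = set(finals) & states

    # Inverse transitions: which states reach `dst` on `sym`?
    inverse = defaultdict(set)
    for (src, sym), dst in delta.items():
        inverse[(dst, sym)].add(src)

    partition = [block for block in (finals, states - finals) if block]
    worklist = list(partition)

    while worklist:
        splitter = worklist.pop()
        for sym in alphabet:
            # All states that step into the splitter on `sym`.
            preds = set()
            for state in splitter:
                preds |= inverse.get((state, sym), set())
            if not preds:
                continue

            refined = []
            for block in partition:
                inside, outside = block & preds, block - preds
                if not inside or not outside:
                    refined.append(block)  # block is not split by preds
                    continue
                refined += [inside, outside]
                if block in worklist:
                    worklist.remove(block)
                    worklist += [inside, outside]
                else:
                    # Processing the smaller half is enough.
                    worklist.append(min(inside, outside, key=len))
            partition = refined

    return [frozenset(block) for block in partition]


# Redundant DFA for "ends in a": states 1 and 2 are equivalent.
delta = {
    (0, "a"): 1, (0, "b"): 0,
    (1, "a"): 2, (1, "b"): 0,
    (2, "a"): 1, (2, "b"): 0,
}
print(hopcroft_partition({0, 1, 2}, "ab", delta, {1, 2}))
```

Unlike the double reversal, this never builds an intermediate automaton larger than the input; it only splits blocks of existing states.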

> However, it remains super slow even if I remove the above, so it would not help anyway.

I can look into optimizing the subsequent steps as well.

viktor-ferenczi commented 8 months ago

Maybe we could use Python's compiled regex representation to come up with a better FSM ourselves without having to use interegular. Even if it were limited, it might work more stably and efficiently.

Python's regex

The compiler parses the regex into a kind of byte-code (see OPCODES), then the matching engine executes it.
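
The parsed form can be inspected with the documented `re.DEBUG` flag, which prints the parse tree (and, on recent Python versions, the compiled op-codes) to standard output. A small sketch that captures this output as a string follows; `dump_parsed_regex` is a hypothetical helper name, and the exact layout of the dump varies between Python versions.

```python
import contextlib
import io
import re


def dump_parsed_regex(pattern: str) -> str:
    """Compile `pattern` with re.DEBUG and return what it prints."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        re.compile(pattern, re.DEBUG)
    return buffer.getvalue()


# Shows a LITERAL node for "a" and a MAX_REPEAT node wrapping "b".
print(dump_parsed_regex(r"ab*"))
```

Building an FSM from this representation would still require handling each node type separately, so this only shows what the starting point would look like.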

viktor-ferenczi commented 8 months ago

I think a better approach would be to get the Lark grammar working with the vLLM endpoint (outlines.serve.serve), so we can use that for the more complex cases instead of regex.

viktor-ferenczi commented 8 months ago

Please do not use exception handling for logic. Thanks.

                j = state_idx.get(next_hash)
                if j is None:
                    j = len(states)
                    states.append(next)
                    if next_hash not in state_idx:
                        state_idx[next_hash] = j

viktor-ferenczi commented 8 months ago

Thank you for the crawl optimization. I've just tried it, but it's not nearly enough to get my query working. I need to look for alternatives. The best would be to finish adding support for the Lark grammar, so I can skip interegular for these kinds of queries.

lapp0 commented 8 months ago

The crawl optimization mainly improves the performance of to_fsm(). For reduce() we need Hopcroft I think.

Hoping to get CFG for vLLM soon: https://github.com/outlines-dev/outlines/pull/676

viktor-ferenczi commented 8 months ago

Trying to use this simpler regex in the meantime, but it allows the model to produce wrong output, which is not ideal:

(Path: `(.*?)`\n\n`{3}([a-z]+)\n(\n|[^`].*?\n)*`{3}\n\n)+(New: `(.*?)`\n\n`{3}([a-z]+)\n(\n|[^`].*?\n)*`{3}\n\n)*
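
One way the simpler regex can be too permissive is that it no longer pins the file names, so any path is accepted. A quick check with Python's `re` module illustrates this; note that `re` is a backtracking engine and not how outlines evaluates the pattern, so this only tests which strings belong to the language. The `block` helper and the sample paths are made up for the example.

```python
import re

SIMPLIFIED = re.compile(
    r"(Path: `(.*?)`\n\n`{3}([a-z]+)\n(\n|[^`].*?\n)*`{3}\n\n)+"
    r"(New: `(.*?)`\n\n`{3}([a-z]+)\n(\n|[^`].*?\n)*`{3}\n\n)*"
)


def block(kind: str, path: str, lang: str, body: str) -> str:
    """Build one 'Path:' or 'New:' section in the expected layout."""
    fence = "`" * 3
    return f"{kind}: `{path}`\n\n{fence}{lang}\n{body}\n{fence}\n\n"


known = block("Path", "Shop.Data/IOrder.cs", "cs", "interface IOrder {}")
invented = block("Path", "Some/Invented/File.cs", "cs", "class X {}")

print(bool(SIMPLIFIED.fullmatch(known)))     # True
print(bool(SIMPLIFIED.fullmatch(invented)))  # True: the path is unconstrained
```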