ml-explore / mlx-examples

Examples in the MLX framework
MIT License
6.3k stars 898 forks source link

Getting TypeError: 'module' object is not subscriptable on generations with prompts over a certain size #1022

Closed chimezie closed 1 month ago

chimezie commented 1 month ago

I'm getting the following traceback when generating from models using the latest mlx/mlx-examples from git:

Traceback (most recent call last):
  [..snip..]
  File "/path/to/mlx-examples/llms/mlx_lm/models/base.py", line 39, in create_attention_mask
    if cache is not None and cache[0] is not None:
                             ~~~~~^^^
TypeError: 'module' object is not subscriptable

The script below and its two invocations demonstrate the error on randomly generated text whose size falls within a given range: in one run most of the generations produce the error, and in the other none do. The only thing that seems to determine whether the error occurs is the size of the input, not the model (I have tried others).

#!/usr/bin/env python
import click
def get_random_words(lowest_num_words, highest_num_words):
    """Yield an endless stream of random word samples.

    This is a generator: the word list is downloaded once, on the first
    ``next()`` call, and then reused for every sample.

    Args:
        lowest_num_words: Smallest number of words a sample may contain.
        highest_num_words: Largest number of words a sample may contain
            (inclusive).

    Yields:
        A list of distinct words whose length is drawn uniformly from
        ``[lowest_num_words, highest_num_words]``.

    Raises:
        requests.RequestException: If the download times out or the server
            answers with an HTTP error status.
        ValueError: If ``lowest_num_words > highest_num_words``, or if the
            drawn sample size exceeds the number of words in the list.
    """
    import random

    import requests

    url = "https://www.mit.edu/~ecprice/wordlist.10000"
    # Without a timeout requests can block forever on a stalled connection,
    # and without the status check an HTTP error page would be split into
    # "words" and silently used as the sampling population.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    wordlist = response.text.splitlines()
    while True:
        num_words = random.randint(lowest_num_words, highest_num_words)
        yield random.sample(wordlist, num_words)

# Command-line entry point: load MODEL, run `num_generations` summarisation
# prompts built from random word samples, and report every TypeError raised
# during generation together with the token length of the offending prompt.
@click.command()
@click.option('-l', '--lowest-num-words', default=10, type=int)
@click.option('-u', '--highest-num-words', default=100, type=int)
@click.option('-n', '--num-generations', default=10, type=int)
@click.argument('model')
def main(lowest_num_words, highest_num_words, num_generations, model):
    # NOTE: deliberately no docstring here -- click would surface it as the
    # command's --help text, which would change the CLI output.
    from mlx_lm.utils import load, generate
    import traceback

    model, tokenizer = load(model)

    # Maps prompt token count -> set of distinct formatted tracebacks.
    errors = {}
    word_samples = get_random_words(lowest_num_words, highest_num_words)
    # range() comes first in zip() so the endless generator is never advanced
    # past the requested number of generations.
    for _, words in zip(range(num_generations), word_samples):
        num_words = len(words)
        text = ' '.join(words)
        num_tokens = len(tokenizer.encode(text))
        print(f"{num_words :,} words, {num_tokens :,} tokens")
        try:
            generate(model, tokenizer, f"Summarize the text below: \n{text}")
        except TypeError:
            if num_tokens not in errors:
                errors[num_tokens] = set()
            errors[num_tokens].add(traceback.format_exc())

    print(f"Errors in {len(errors):,}/{num_generations:,} . Token lengths: {list(errors)}")

    # Identical tracebacks seen at different token lengths are printed once.
    distinct_traces = set().union(*errors.values())
    for error_trace in distinct_traces:
        print(error_trace)
# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Small run

% ./snippets.py -l 100 -u 400 raw_models/mlx/Gemma-2-9B-It-SPPO-Iter3
320 words, 361 tokens
284 words, 310 tokens
267 words, 292 tokens
143 words, 157 tokens
281 words, 302 tokens
320 words, 351 tokens
240 words, 262 tokens
236 words, 255 tokens
115 words, 122 tokens
120 words, 132 tokens
Errors in 0/10 . Token lengths: []

Slightly larger run

% ./snippets.py -l 400 -u 800 raw_models/mlx/Gemma-2-9B-It-SPPO-Iter3
478 words, 536 tokens
747 words, 815 tokens
619 words, 681 tokens
526 words, 573 tokens
613 words, 668 tokens
783 words, 848 tokens
429 words, 462 tokens
426 words, 476 tokens
532 words, 578 tokens
472 words, 509 tokens
Errors in 8/10 . Token lengths: [536, 815, 681, 573, 668, 848, 578, 509]
Traceback (most recent call last):
  File "/path/to/./snippets.py", line 29, in main
    generate(model, tokenizer, f"Summarize the text below: \n{text}")
  File "/path/to/mlx-examples/llms/mlx_lm/utils.py", line 335, in generate
    for n, (token, logprobs) in zip(
  File "/path/to/mlx-examples/llms/mlx_lm/utils.py", line 242, in generate_step
    model(y[:prefill_step_size][None], cache=cache)
  File "/path/to/mlx-examples/llms/mlx_lm/models/gemma2.py", line 192, in __call__
    out = self.model(inputs, cache)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/mlx-examples/llms/mlx_lm/models/gemma2.py", line 168, in __call__
    mask = create_attention_mask(h, cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/mlx-examples/llms/mlx_lm/models/base.py", line 39, in create_attention_mask
    if cache is not None and cache[0] is not None:
                             ~~~~~^^^
TypeError: 'module' object is not subscriptable