ml-explore / mlx-swift-examples

Examples using MLX Swift

Crash on `concatenate` after latest update #114

Open DePasqualeOrg opened 3 weeks ago

DePasqualeOrg commented 3 weeks ago

After https://github.com/ml-explore/mlx-swift-examples/commit/ab94ffc2f31a70ead3c7007afaf97a225ed3ec90, I'm getting a crash the second time I try to generate text with my app, which uses mlx-libraries. I can't reproduce this with the LLMEval example app at the moment, but I'll try to find the cause.

MLX error: [concatenate] All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (512,8,1,128), (1,8,512,128), and the concatenation axis is 2. at /Users/<user>/Library/Developer/Xcode/DerivedData/<app name>-ejjtjaklhfhyarhbwjdbxiatlsar/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/ops.cpp:217

awni commented 3 weeks ago

Which model were you running?

DePasqualeOrg commented 3 weeks ago

So far it has happened on Phi 3.5 mini 4-bit and Llama 3.1 8B 4-bit. I haven't tested other models yet.

awni commented 3 weeks ago

Can you provide more details on what exactly you are running? Most likely the line it's breaking on is in the new KV cache. It looks like one of the inputs to that function has its dimensions in the wrong order.
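
For illustration, concatenation requires every dimension except the concatenation axis to match, which is why a transposed input fails with exactly that message. A minimal sketch, not the actual cache code, assuming the MLX Swift `zeros` and `concatenated` free functions:

    import MLX

    // KV cache laid out as (batch, heads, sequence, headDim)
    let cache = zeros([1, 8, 512, 128])

    // New keys in the same layout concatenate fine on the sequence axis (2)
    let good = zeros([1, 8, 1, 128])
    print(concatenated([cache, good], axis: 2).shape)  // [1, 8, 513, 128]

    // New keys with batch and sequence swapped, shaped (512, 8, 1, 128),
    // raise the error above: dimensions 0 and 2 no longer match.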

Are you using custom model code or the same models as in the example?

DePasqualeOrg commented 3 weeks ago

I'm using the models from mlx-libraries. Generally this happens on the second or third prompt in a conversation. I'm still trying to investigate this on my end but wanted to open this issue in case others are having similar problems.

davidkoski commented 3 weeks ago

I was using mlx-community/Phi-3-mini-4k-instruct-4bit as the primary use case so I know that one works generally.

Is there something I can do to reproduce the issue? I am happy to debug it.

davidkoski commented 3 weeks ago

Generally this happens on the second or third prompt in a conversation

How do you do the conversation? How is the state carried from one call to `generate` to the next?

DePasqualeOrg commented 3 weeks ago

I use the prompt templates that are commonly used for each model to represent the conversation, appending each new prompt and response and passing the updated template to `generate` when a new prompt is submitted. I've had to build this myself, since swift-transformers doesn't include chat templating, although this may change soon. I'll post an update here when I can reproduce this with LLMEval.
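
For example, a stripped-down version of what I do for Llama 3 style models looks roughly like this (a hypothetical helper, not my actual code):

    // Hypothetical sketch: assemble a multi-turn prompt in the Llama 3 chat format
    struct Turn {
        let role: String  // "system", "user", or "assistant"
        let content: String
    }

    func buildPrompt(_ turns: [Turn]) -> String {
        var prompt = "<|begin_of_text|>"
        for turn in turns {
            prompt += "<|start_header_id|>\(turn.role)<|end_header_id|>\n\n"
            prompt += turn.content + "<|eot_id|>"
        }
        // Cue the model to produce the next assistant turn
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
    }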

DePasqualeOrg commented 3 weeks ago

I've been able to reproduce this with LLMEval after updating the swift-transformers dependency to use the latest commit on the main branch. Short prompts work as expected, but after submitting a long prompt (about 5600 characters) with Llama 3.1 8B 4-bit, I get this crash:

MLX error: [scaled_dot_product_attention] mismatching batch dimension for input with shape (512,8,170,128). at /Users/<user>/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eiqkagimbcumwufwrjncqseqpfjo/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/fast.cpp:61

And with Phi 3.5 mini:

MLX error: [concatenate] All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (512,32,2,96), (1,32,512,96), and the concatenation axis is 2. at /Users/<user>/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eiqkagimbcumwufwrjncqseqpfjo/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/ops.cpp:217

I noticed that the active memory indicator in the app grows to a very large number when this happens.

davidkoski commented 3 weeks ago

ok perfect, I have a repro using the output of the previous runs:

p2.txt

davidkoski commented 3 weeks ago

actually a different error:

<|assistant|>MLX error: [scaled_dot_product_attention] mismatching batch dimension for input with shape (512,32,182,96). at /Users/dkoski/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eimbjcofifunwybkcvhnzjbqwyri/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/fast.cpp:61

hopefully that is related though. I also see the big memory spike.

(actually this matches your other error)

davidkoski commented 3 weeks ago

So something is off in the KVCache. The shape of the keys after the prompt:

python: (1, 32, 512, 96)
swift: (256, 32, 256, 96)

The 256 vs 512 is because I messed with the prefill step size. Anyway the 0 dimension is not right.
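
For reference, the cache update appends along the sequence axis, so it assumes a (batch, heads, sequence, headDim) layout for both operands. A simplified sketch, not the actual KVCache code:

    import MLX

    // Simplified cache update: new keys must arrive as (B, H, L, D) so that
    // only the sequence axis (2) differs between the two operands.
    final class SimpleKVCache {
        var keys: MLXArray?

        func update(keys newKeys: MLXArray) -> MLXArray {
            if let existing = keys {
                keys = concatenated([existing, newKeys], axis: 2)
            } else {
                keys = newKeys
            }
            return keys!
        }
    }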

davidkoski commented 3 weeks ago

This comes from different shapes as input to Attention:

python: (1, 512, 3072)
swift: [256, 1, 3072]

aha:

    model(y[:prefill_step_size][None], cache=cache)

does not translate to:

    _ = model(y[..<parameters.prefillStepSize, .newAxis], cache: cache)

the trailing [None] actually needs to come first in the Swift indexing:

    y[.newAxis, ..<parameters.prefillStepSize]
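
To see the difference with a 1-D token array (illustrative shapes only):

    import MLX

    let y = MLXArray(0 ..< 1024)     // shape [1024]
    let wrong = y[..<512, .newAxis]  // shape [512, 1]: batch 512, sequence 1
    let right = y[.newAxis, ..<512]  // shape [1, 512]: batch 1, sequence 512
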
davidkoski commented 3 weeks ago

It is amazing that this gross mismatch of shapes ... mostly works. It sure would be nice to have some typing on shapes. I suppose we could use `precondition`.
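
For example, a guard like this at the top of the attention layer would have failed immediately with a readable message (an illustrative sketch; the names are made up):

    import MLX

    // Hypothetical shape guard: fail fast instead of deep inside concatenate
    func checkAttentionInput(_ x: MLXArray, hiddenSize: Int) {
        // Input should be (batch, sequence, hiddenSize)
        precondition(x.ndim == 3, "attention expects a 3D input, got \(x.shape)")
        precondition(x.shape.last == hiddenSize,
                     "expected hidden size \(hiddenSize), got \(x.shape)")
    }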

davidkoski commented 3 weeks ago

@DePasqualeOrg thank you so much for reporting this! Your info got me a quick repro and I was able to track down the issue. You can try kvcache2 from #115.

DePasqualeOrg commented 3 weeks ago

Fantastic, thank you! I tested this with Phi 3.5 mini and Llama 3.1 8B, and it mostly seems to work, but on longer, multi-turn prompts I got garbled output from Phi 3.5 mini and special tokens like assistant<|end_header_id|> from Llama 3.1 8B. I guess this is due to the new KV cache?

I'm also curious how you would recommend estimating the required memory for a given prompt with this new approach.

awni commented 3 weeks ago

I tested this with Phi 3.5 mini and Llama 3.1 8B, and it mostly seems to work, but on longer, multi-turn prompts I got garbled output from Phi 3.5 mini and special tokens like assistant<|end_header_id|> from Llama 3.1 8B

The handling of the RoPE positional encodings is not quite right for both Llama 3.1 and Phi 3.5, so if your prompt + generation is very long (like 4k tokens or more) that might explain it. The new KV cache shouldn't change the results at all; if you are finding that it does, then that is a bug. We'll want to update to the latest MLX to fix this.
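
For context, the part the cache does control is the position offset: with a KV cache, each new token has to be rotated at its absolute position rather than position zero. A minimal sketch, assuming the `MLXFast.RoPE` call used by the example models:

    import MLX
    import MLXFast

    // Rotate queries at their absolute positions by passing the number of
    // tokens already in the cache as the offset.
    func rotated(_ queries: MLXArray, headDim: Int, cacheOffset: Int) -> MLXArray {
        MLXFast.RoPE(
            queries, dimensions: headDim, traditional: false,
            base: 10_000, scale: 1.0, offset: cacheOffset)
    }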

I'm also curious how you would recommend estimating the required memory for a given prompt with this new approach.

Since the attention steps are fixed at 512, the maximum size of the attention scores is now 512 * 512 * num_heads * 2, which is not that big. The memory bottleneck for long prompts will most likely be the memory used by the KV cache, which scales as the product of the prompt length, the number of layers, the number of KV heads, the head dimension, and the bytes per element.
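
As a rough illustration with Llama 3.1 8B's published configuration (32 layers, 8 KV heads, head dimension 128) and float16 cache entries:

    2 (K and V) * 32 layers * 8 heads * 128 dims * 2 bytes = 128 KiB per token
    4096 tokens * 128 KiB ≈ 512 MiB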

DePasqualeOrg commented 4 days ago

The new KV cache shouldn't change the results at all; if you are finding that it does, then that is a bug. We'll want to update to the latest MLX to fix this.

Whenever you're able to update to the latest MLX, I'll test this again and see if that solves the problem.