tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

[Mixtral8x7B] Identify and fix the source of a specific bad output #12206

Open mtairum opened 1 month ago

mtairum commented 1 month ago

**Describe the bug**

When running a batch of 32 users with prefill lengths <256 tokens, there are two specific prompts that cause bad output.

One of the problematic prompts is the following:

[INST] What is 2+2? This basic arithmetic question is one of the first math problems we learn as children. The answer is 4, but the concept of addition is much more than just numbers. Think about how you use addition in everyday life, from counting items in your shopping cart to calculating the total cost of your purchases. How has your understanding of math evolved since you first learned to add? Do you enjoy working with numbers, or do you find it challenging? Consider how basic math skills lay the foundation for more complex problem-solving in fields like science, engineering, and finance. Reflect on the importance of math in your daily activities and education. [/INST]

Generates the bad output:

The the the the the the the the the the....

In very few outputs I've seen strange occurrences like `The the currency of Argentina is...` or `The offici currency of...`, which means we will have to be extra attentive to every single output and on the lookout for these kinds of mistakes.

List of things I've tried so far:

List of things to try:

mtairum commented 1 month ago

Managed to replicate the issue with a batch of 32, but with just 8 layers, which makes re-running much faster than before 🙏

It's very dependent on the prompts used. I have an input file where only 8 users (out of 32) are not given the 'bad' prompt. This replicates the issue, and running with 8 layers already shows the bad output.

Running with just 1/2/4 layers produces similar output.
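For anyone reproducing this, a minimal sketch of this kind of layer truncation on a PyTorch reference model (illustrative only; it assumes a model exposing a `layers` `ModuleList`, which is not necessarily how the tt-metal model runner configures layer count):

```python
import torch.nn as nn

def truncate_layers(model: nn.Module, n_layers: int = 8) -> nn.Module:
    # Keep only the first n_layers decoder blocks to make repro runs faster.
    # Assumes the model stores its decoder blocks in `model.layers`.
    model.layers = nn.ModuleList(list(model.layers)[:n_layers])
    return model
```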

So I'm looking at the KV cache now.

I have a feeling that the issue is likely somewhere inside either the new `paged_update_cache` op, which allows for cache updates at different positions per user, or SDPA decode (`scaled_dot_product_attention_decode`), which also depends on per-user positions.

Since I'm finding that the prompts used are very sensitive in triggering this issue, it might be due to differences in the per-user indices being passed to the ops above.

I'll dig further into the unit tests for these ops.
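For context, a minimal PyTorch sketch of the semantics in question: each user in the batch writes its new K/V entry at a different sequence position. All names and shapes here are illustrative, not the actual ttnn API:

```python
import torch

batch, n_heads, max_seq_len, head_dim = 32, 8, 4096, 128
kv_cache = torch.zeros(batch, n_heads, max_seq_len, head_dim)
new_kv = torch.randn(batch, n_heads, 1, head_dim)  # one new token per user
update_idxs = torch.randint(0, 256, (batch,))      # per-user write positions

for b in range(batch):
    # Each user writes at its own position. A bug in how these per-user
    # indices are consumed would corrupt the cache for some users but not
    # others, which would match the prompt-dependent failures seen here.
    kv_cache[b, :, update_idxs[b], :] = new_kv[b, :, 0, :]
```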

mtairum commented 1 month ago

The problem is in decode mode. Looking further into SDPA decode.

The KV caches after prefill look identical between the good and bad versions.

On decode, at the first iteration of layer 0 we see (`torch.allclose`):

From layer 1 onwards, they both mismatch until the end.
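A minimal sketch of the kind of layer-by-layer comparison being described, assuming both runs dump per-layer activations as tensors (names are illustrative):

```python
import torch

def compare_runs(good_acts, bad_acts, rtol=1e-2, atol=1e-2):
    # good_acts / bad_acts: lists of tensors, one per decoder layer,
    # captured from the good and bad runs at the same decode iteration.
    for i, (g, b) in enumerate(zip(good_acts, bad_acts)):
        ok = torch.allclose(g, b, rtol=rtol, atol=atol)
        max_err = (g - b).abs().max().item()
        print(f"layer {i}: allclose={ok}, max_abs_err={max_err:.6f}")
```

The first layer at which `allclose` flips to `False` localizes which op to inspect.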

mtairum commented 1 month ago

Fixes are now in PR: https://github.com/tenstorrent/tt-metal/pull/12506

The issue was `rot_mat` not being height-sharded, leading to bad accuracy in the Q/K @ `rot_mat` matmuls.
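For reference, rotary position embeddings can be applied as a matmul with a block-diagonal rotation matrix, which is why the layout/sharding of `rot_mat` feeds directly into Q/K accuracy. A minimal PyTorch sketch of that formulation (illustrative, not the tt-metal implementation):

```python
import torch

def rot_mat_for_position(pos: int, head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Build a [head_dim, head_dim] rotation matrix for one position,
    # equivalent to the usual cos/sin RoPE formulation but expressed as a
    # matrix so it can be applied with a single Q @ rot_mat / K @ rot_mat.
    theta = base ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    angles = pos * theta
    cos, sin = torch.cos(angles), torch.sin(angles)
    mat = torch.zeros(head_dim, head_dim)
    for i in range(head_dim // 2):
        mat[2 * i, 2 * i] = cos[i]
        mat[2 * i + 1, 2 * i + 1] = cos[i]
        mat[2 * i, 2 * i + 1] = sin[i]
        mat[2 * i + 1, 2 * i] = -sin[i]
    return mat

q = torch.randn(1, 8, 1, 128)  # [batch, heads, seq, head_dim]
q_rot = q @ rot_mat_for_position(42, 128)
```

Any precision loss or layout mismatch in `rot_mat` propagates into every Q and K vector before attention, which is consistent with the degenerate outputs observed.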