mtairum opened this issue 1 month ago
Managed to replicate the issue with a batch of 32, but with just 8 layers, which makes re-running much faster than before 🙏
It's very dependent on the prompts used. I have an input file where only 8 users (out of 32) do not get the 'bad' prompt. This replicates the issue, and running 8 layers already shows bad output.
Running just 1/2/4 layers produces similar output.
So I'm looking at the KV cache now.
I have a feeling that the issue is likely somewhere inside either the new `paged_update_cache` op, which allows for cache updates at different positions per user, or the SDPA (`scaled_dot_product_attention_decode`), which also depends on different user positions.
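For my own reference, the per-user-position cache update is conceptually something like the plain-torch sketch below. The shapes and names here are placeholders I made up for illustration, not the actual op signature:

```python
import torch

def update_cache_per_user(cache: torch.Tensor,     # [batch, n_kv_heads, max_seq_len, head_dim]
                          new_kv: torch.Tensor,    # [batch, n_kv_heads, 1, head_dim], one decode step
                          positions: torch.Tensor  # [batch], current write position for each user
                          ) -> torch.Tensor:
    """Write each user's new K (or V) row into that user's own position in the cache."""
    for b in range(cache.shape[0]):
        cache[b, :, positions[b], :] = new_kv[b, :, 0, :]
    return cache

# Example: 32 users, each sitting at a different position in its cache.
cache = torch.zeros(32, 8, 4096, 128)
new_kv = torch.randn(32, 8, 1, 128)
positions = torch.randint(0, 256, (32,))
update_cache_per_user(cache, new_kv, positions)
```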
Since I'm finding that the prompts used are very sensitive in triggering this issue, it might be due to differences in the indices of the users being passed to the ops above.
I'll check the unit tests for these ops further.
The problem is in decode mode. Looking further into SDPA decode.
The KV caches after prefill look identical between the good and bad versions.
On decode, at layer 0's first iteration we see (comparing with `torch.allclose`):
From layer 1 onwards, the outputs mismatch until the end.
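The comparison I'm running between the two versions is roughly this (the tensor lists are placeholders for the per-layer outputs I'm dumping from each run, and the tolerances are just a starting point):

```python
import torch

# good_outs / bad_outs: lists of per-layer decode outputs dumped from the two runs.
for layer_idx, (good, bad) in enumerate(zip(good_outs, bad_outs)):
    match = torch.allclose(good, bad, rtol=1e-3, atol=1e-3)
    max_err = (good - bad).abs().max().item()
    print(f"layer {layer_idx}: allclose={match}, max_abs_err={max_err:.3e}")
```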
Fixes are now in PR: https://github.com/tenstorrent/tt-metal/pull/12506
The issue was with `rot_mat` not being height-sharded, leading to bad Q/K @ `rot_mat` accuracy.
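For context, `rot_mat` here is the per-user rotary matrix that gets matmul'ed with Q and K at decode time. In plain torch the math is roughly the sketch below; the exact layout used in the model may differ, this is just to illustrate why it is position-dependent per user:

```python
import torch

def build_rot_mat(pos: int, head_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Block-diagonal rotation matrix for one position (RoPE expressed as a matmul)."""
    mat = torch.zeros(head_dim, head_dim)
    for i in range(0, head_dim, 2):
        theta = torch.tensor(pos / (base ** (i / head_dim)))
        c, s = torch.cos(theta), torch.sin(theta)
        mat[i, i], mat[i, i + 1] = c, s
        mat[i + 1, i], mat[i + 1, i + 1] = -s, c
    return mat

# Each user sits at a different position, so each user gets its own rot_mat.
batch, n_heads, head_dim = 32, 8, 128
q = torch.randn(batch, n_heads, 1, head_dim)
positions = torch.randint(0, 256, (batch,))
rot_mat = torch.stack([build_rot_mat(int(p), head_dim) for p in positions]).unsqueeze(1)
q_rotated = torch.matmul(q, rot_mat)   # [batch, n_heads, 1, head_dim]
```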
Describe the bug
When running a batch of 32 users with prefill lengths <256 tokens, there are two specific prompts that cause bad output.
One of the problematic prompts is the following:
Generates the bad output:
In very few outputs I've seen strange occurrences like:
"The the currency of Argentina is..."
or "The offici currency of...",
which means that we will have to be extra attentive to every single output and be on the lookout for these kinds of mistakes.

List of things I've tried so far
List of things to try:
- `test_model_prefill` to also run on the torch model (rough sketch below)
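Rough sketch of that idea, with placeholder names for the reference model and tt output, and PCC as just one possible closeness metric:

```python
import torch

def check_prefill_against_torch(reference_model, prefill_tokens, tt_output,
                                pcc_threshold: float = 0.99):
    """Run the same prefill input through the torch reference model and compare outputs."""
    with torch.no_grad():
        ref_output = reference_model(prefill_tokens)
    ref_flat = ref_output.flatten().float()
    tt_flat = tt_output.flatten().float()
    # Pearson correlation between flattened outputs as a closeness score.
    pcc = torch.corrcoef(torch.stack([ref_flat, tt_flat]))[0, 1].item()
    assert pcc > pcc_threshold, f"PCC {pcc:.5f} below threshold {pcc_threshold}"
```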