tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

Context Leakage - Mistral 7B #11061

Closed mvanniasingheTT closed 1 month ago

mvanniasingheTT commented 1 month ago

Describe the bug

There is context leakage in the Mistral 7B model when batching is implemented. Modulo wrapping is used to cycle prompts from a batch of size 8 between users.
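As a rough illustration of the modulo wrapping, the sketch below rotates a list of 8 prompts by one user slot per batch. The function names and the rotation direction are assumptions, not taken from demo.py; the direction is inferred from the pairs reported later in this thread (the prompt of batch 0, user 0 reappears at batch 1, user 7, and the prompt of batch 0, user 7 reappears at batch 1, user 6).

```python
NUM_USERS = 8  # batch size stated in the report


def prompt_index(batch: int, user: int, num_prompts: int = NUM_USERS) -> int:
    """Hypothetical rotation: shift each user's prompt by one slot per batch."""
    return (user + batch) % num_prompts


def assign_prompts(prompts, num_batches):
    """Return, for each batch, the prompts ordered by user slot."""
    n = len(prompts)
    return [
        [prompts[prompt_index(b, u, n)] for u in range(n)]
        for b in range(num_batches)
    ]


# Placeholder prompt names; the real demo uses actual questions.
prompts = [f"prompt {i}" for i in range(NUM_USERS)]
schedule = assign_prompts(prompts, num_batches=3)
for b, row in enumerate(schedule):
    print(f"batch {b}: user 0 -> {row[0]!r}, user 7 -> {row[7]!r}")
```

With a rotation like this, every user slot receives a different prompt in each batch, so a leak shows up as text about the previous batch's prompt appearing in the output of the same user slot.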

For example:

prompt: What is your favourite condiment?
output:
 [/INST] I don't have a physical body or the ability to have favourites. However, I can tell you that many people enjoy condiments like ketchup, mustard, soy sauce, hot sauce, mayonnaise, and relish. The favourite condiment can depend on personal taste preferences.

If you have a specific condiment in mind, I can provide you with more information about it.

For example, ketchup is a popular condiment made from tomatoes
prompt: Hello, how are you?
output:
 [/INST] I'm just a computer program, so I don't have the ability to feel emotions or have a personality. I'm here to help answer any question you might have. Is there a specific question you have in mind? I'll be happy to help answer any question you might have. Let me know if you have any question related to it.

Ketchup is a tomato sauce commonly used as a condiment for various dishes. It's made by cooking tomatoes, onions, garlic

Another example:

batch: 0 user: 2
prompt: Do you have mayonnaise recipes?
output:
 [/INST] Yes, I can provide you with a simple and classic mayonnaise recipe.

Ingredients:
- 1 cup (225 ml) vegetable oil
- 1 egg yolk
- 1 tablespoon (15 ml) white wine vinegar or lemon juice
- 1 teaspoon (5 g) Dijon mustard (optional)
- Salt and pepper to taste

batch: 1 user: 2
prompt: Which color do you get if you mix yellow and blue?
output:
 [/INST] When you mix yellow and blue paint, you get green color. However, in the case of making mayonnaise or green paint, you actually mix yellow and a little bit of blue to create a green hue. So, the answer to your question is green.m/r/1026632/r/1026632/r/1026632/r/1026632/r/1

To Reproduce

Steps to reproduce the behavior:

  1. Run the model demo: https://github.com/tenstorrent/tt-metal/blob/mvanniasinghe/mistral7B_n_batches/models/demos/wormhole/mistral7b/demo/demo.py
  2. Optionally, set print_to_file to true to save the model output to a file for easier viewing.
  3. Run the modified pytest command below, where the number of batches can be 1-5 for the issue to be seen: pytest models/demos/wormhole/mistral7b/demo/demo.py::test_mistral7B_demo[<weights>-<number of batches>_batch]

Additional context

Refer to the following PR: https://github.com/tenstorrent/tt-metal/pull/11060

mtairum commented 1 month ago

For reference, the latest commit in branch mvanniasinghe/mistral7B_n_batches still has this issue, correct?

I've run the general weights demo and can see leakage between batches for the same user.

mvanniasingheTT commented 1 month ago

Hi, yes, that is correct.

mtairum commented 1 month ago

Added a fix for the cache leakage; the cache wasn't being reset back to the model: https://github.com/tenstorrent/tt-metal/pull/11060/commits/a2a1c96fd10f9f5c38dd6767b3df6e0c713d6a35
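To make the failure mode concrete, here is a toy sketch of a per-user cache that is either reused or reset between batches. It is not tt-metal code and does not reproduce the linked commit: `ToyKVCache`, `run_batches`, and the idea that the reset happens once per batch are all illustrative assumptions, and tokens are stored as plain strings instead of key/value tensors.

```python
class ToyKVCache:
    """Stand-in for a per-user KV cache; stores tokens instead of tensors."""

    def __init__(self, num_users):
        self.num_users = num_users
        self.reset()

    def reset(self):
        # Give every user slot a fresh, empty history.
        self.slots = [[] for _ in range(self.num_users)]

    def append(self, user, token):
        self.slots[user].append(token)

    def context(self, user):
        # Copy so callers cannot mutate the cache by accident.
        return list(self.slots[user])


def run_batches(batches, reset_between_batches):
    """Feed each batch through the cache; return the context seen per batch."""
    cache = ToyKVCache(num_users=len(batches[0]))
    seen = []
    for prompts in batches:
        if reset_between_batches:
            cache.reset()
        for user, prompt in enumerate(prompts):
            for token in prompt.split():
                cache.append(user, token)
        seen.append([cache.context(u) for u in range(cache.num_users)])
    return seen


# One user slot, two batches with unrelated prompts.
batches = [["mayonnaise recipes"], ["yellow and blue"]]
leaky = run_batches(batches, reset_between_batches=False)
clean = run_batches(batches, reset_between_batches=True)
print(leaky[1][0])  # second batch still sees the first batch's tokens
print(clean[1][0])  # second batch sees only its own prompt
```

Without the reset, the second batch's context for a user slot still contains the first batch's tokens, which matches the kind of cross-batch leakage reported above.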

Although this issue is now fixed, I'm seeing worse outputs from batch 2 onwards. For the same prompt I would expect the same output, but the output is not only different, it is also worse.

E.g. Batch 0, user 7:

batch: 0 user: 7
prompt: What is 2+2? 
output:
 The answer to the expression "2 + 2" is 4. This is a basic arithmetic problem, and the answer is a well-known fact. 

If you have any other math questions, feel free to ask! 

Here's a simple explanation of how to solve the expression "2 + 2":

1. Identify the numbers involved in the expression. In this case, the numbers are 2 and 2.
2. Follow the order

Then batch 1, user 6 (same prompt as 7 in batch 0):

batch: 1 user: 6
prompt: What is 2+2? 
output:
 It seems that there is a typo in your question. I'll assume you meant to ask "What is 2 2 mean?"

Unfortunately, without more context, it's difficult to provide a definitive answer. However, based on the context provided, it seems that "2+2" might be a typo or a reference to something specific to a particular context. Without further context, it's impossible to determine the meaning of "2+2". I hope this helps clar

Looking into the KV caches to see if they differ between batches.

mtairum commented 1 month ago

I've been having a thorough look at this and haven't yet found a fix.

I focused on a specific prompt and looked at the KV cache for it across 2 batches and their users. The prompt "What is your favourite condiment?" is used by user 0 in batch 0 and user 7 in batch 1. The outputs start diverging at generated token 7 (token 22 from the start). Below are the top-3 values:

[batch0,user0, 7 tokens]: I don't have a physical
[batch0,user0, 7 tokens]: First difference at token 22. Top 3: values=[17.2500, 16.8750, 16.1250] indices=[ 5277,  6656,  3327]

[batch1,user7, 7 tokens]: I don't have a favorite
[batch1,user7, 7 tokens]: First difference at token 22. Top 3: values=[17.5000, 15.5000, 14.3125], indices=[ 6656, 16020,  3327]

Comparing the KV caches with torch.allclose returns True up to the point of divergence across batches (from token 0 to token 22), as expected.
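The comparison described above can be sketched in plain Python, with no torch dependency. The helper names are hypothetical, the sequences are split on whitespace instead of using the model's tokenizer, and the logits are made-up toy values, so the indices do not correspond to the token numbers quoted in this comment.

```python
import heapq
import math


def first_divergence(a, b):
    """Index of the first position where two sequences differ, or None."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    # Identical over the shared prefix: they differ only if lengths differ.
    return None if len(a) == len(b) else min(len(a), len(b))


def top_k(logits, k=3):
    """Return (values, indices) of the k largest logits, largest first."""
    pairs = heapq.nlargest(k, enumerate(logits), key=lambda p: p[1])
    return [v for _, v in pairs], [i for i, _ in pairs]


def all_close(a, b, rel_tol=1e-5, abs_tol=1e-8):
    """Elementwise closeness check on flat lists, similar in spirit to torch.allclose."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(a, b)
    )


# Whitespace-split words from the two outputs quoted above (not real tokens).
batch0 = "I don't have a physical".split()
batch1 = "I don't have a favorite".split()
idx = first_divergence(batch0, batch1)
print("first difference at position", idx)

# Toy logits, only to show the top-k helper.
toy_logits = [0.1, 2.5, -1.0, 3.0, 2.9]
print(top_k(toy_logits))
```

Checking `all_close` on the cached values up to `idx` and then inspecting `top_k` at the diverging step mirrors the debugging procedure described in this comment.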

Another thing I found is that the output is consistent across runs, as it should be. The 3 batches I'm running always produce the same outputs across runs, with each successive batch having worse output than the one before. For the above prompt, the 3 batch outputs look like this:

prompt: What is your favourite condiment? 

[Batch 0] I don't have a physical body or the ability to have favourites. However, I can tell you that many people enjoy condiments like ketchup, mustard, soy sauce, hot sauce, mayonnaise, and relish. The favourite condiment can depend on personal taste preferences. 

If you have a specific condiment in mind, I can provide you with more information about it. 

For example, ketchup is a popular condiment made from tomatoes

[Batch 1] I don't have a favorite condiment [sauce. I'm just a computer program, so I don't have personal preferences or favorite condiments. If you have any specific condiment [sauce or other type of seasoning, please let me know and I'll do my best to help you find a great condiment [sauce or other condiment [sauce for that specific dish. I'll be here to help you find the perfect condiment [s

[Batch 2]  It's my favourite condiment is not clear.

�s what is the difference between a condiment and a condiment?
�s the difference between a condiment and a condiment is unclear.

I'd like to clarify that for you. The difference between a condiment and a condiment is that they're not the same.

A condiment and a condiment are not the same thing. They're not the same as they.

To clar

I also found that resetting the device, i.e. reloading all weights, between batches leads to good output: all 3 batches produce the same output for users with the same prompt. However, this is not how we want to run batching, so it only serves to debug the issue.

Below is a summary of things I've tried:

mtairum commented 1 month ago

The fix has been pushed to mvanniasinghe/mistral7B_n_batches.

I'll make one last change to the PR https://github.com/tenstorrent/tt-metal/pull/11060 (reducing the number of new tests that run on CI), then run the CI tests so we can merge to main.

mtairum commented 1 month ago

In main now.