tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

Multiple models: low PCC when program cache is enabled #7159

Closed mikevin920 closed 3 months ago

mikevin920 commented 5 months ago

Describe the bug
When running the Llama2 70b decode and prefill model tests, the PCC is good (~0.999) for a single layer when the program cache is turned off, but the PCC becomes very low when we enable the program cache by adding use_program_cache to the test fixture.

- Decode mode average PCC for 1 layer over 20 tokens without program cache: 0.999497
- Decode mode PCC for 1 layer over 20 tokens with program cache: 0.9759
- Prefill mode (seq_len 128) PCC for 1 layer without program cache: 0.99936
- Prefill mode (seq_len 128) PCC for 1 layer with program cache: unable to repro due to a bug when using tt_lib.tensor.sharded_to_interleaved_partial with program cache

To Reproduce

  1. Build tt-metal with the latest main on a T3000
  2. Run pytest models/demos/llama2_70b/tests/test_llama_model.py::test_LlamaModel_inference[decode-8chip-T3000-1L]
  3. Add use_program_cache to the test fixture in Line 269 of test_llama_model.py
  4. Run the command again
  5. Observe the difference in PCC
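
For clarity, a minimal sketch of what step 3 amounts to, assuming the standard pytest fixture mechanism; the stub fixture and the test body below are placeholders rather than the real tt-metal test code:

```python
import pytest

# Stub standing in for the real fixture provided by tt-metal's test
# infrastructure, which enables (and later clears) the device program cache.
@pytest.fixture
def use_program_cache():
    yield

def test_LlamaModel_inference(use_program_cache):
    # Requesting the fixture is all that step 3 adds to the real test's
    # signature; the body here is a placeholder for the model run and PCC check.
    assert True
```
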
johanna-rock-tt commented 4 months ago

We are seeing the same issue with Falcon40b prefill. We see it for all sequence lengths, but the PCC difference between using and not using the program cache grows significantly with longer sequence lengths. The demo also produces bad outputs when run with use_program_cache.

Falcon40b prefill, S=2048, end-to-end test with 1 layer:

- PCC without program cache enabled: 0.99
- PCC with program cache enabled: 0.49

johanna-rock-tt commented 4 months ago

FYI: @jliangTT @s-jovic @miacim

cglagovichTT commented 4 months ago

Assigned to @jliangTT for him to reassign.

cglagovichTT commented 4 months ago

@TT-BrianLiu sounds like this is affecting Falcon40B prefill. Do you have ideas on how to debug?

TT-BrianLiu commented 4 months ago

I am waiting for other issues (like this one: #6775) to be fixed first.

jliangTT commented 4 months ago

Yes, we are still working on other P0 issues and will need to come back to this one.

pavlepopovic commented 4 months ago

https://github.com/tenstorrent/tt-metal/issues/7704: ShardedToInterleavedPartial and InterleavedToShardedPartial have a bug with program caching enabled. Fixing that resolves the Falcon40B program cache issues that previously existed, so it's worth trying Llama once the fix goes in. You can also try it before it lands in main: https://github.com/tenstorrent/tt-metal/pull/7750

kevinmiTT11 commented 4 months ago

Can confirm this commit solves the PCC issues for Llama prefill.

cglagovichTT commented 4 months ago

Note that we're still seeing poor PCC for llama2-70b in decode mode with program cache enabled.

eyonland commented 4 months ago

Every time there is a PCC failure because the program cache is enabled, a good place to look is whether the op has implemented the compute_program_hash method. If it has, the most likely cause of the bad PCC is a poor hash. If it has not, the issue could well be that both operations, sharded_to_interleaved and interleaved_to_sharded, share the same default hash function and end up using the same typeid:

auto operation_type_hash = typeid(OperationType).hash_code();

If you have a bad hash, you end up re-using the wrong cached code.
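
To illustrate the failure mode, here is a conceptual sketch in plain Python (not the tt-metal caching code): if the cache key omits attributes that change the compiled program, a later call with different attributes silently reuses the wrong program.

```python
# Conceptual illustration of a bad program hash: the cache key ignores
# attributes (e.g. shard spec or layout) that change the compiled program.
program_cache = {}

def compile_program(op_name, attrs):
    # Stand-in for real kernel compilation; records what it was built for.
    return f"{op_name} program compiled for {attrs}"

def get_or_create_program(op_name, attrs):
    key = hash(op_name)            # BAD: attrs are not part of the cache key
    if key not in program_cache:
        program_cache[key] = compile_program(op_name, attrs)
    return program_cache[key]      # may have been compiled for different attrs

print(get_or_create_program("interleaved_to_sharded", {"shard": "height"}))
print(get_or_create_program("interleaved_to_sharded", {"shard": "width"}))  # reuses the wrong program
```
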
@cglagovichTT, do you know which op is causing the bad PCC in decode mode?

kevinmiTT11 commented 4 months ago

@eyonland most likely it's sharded_to_interleaved and interleaved_to_sharded, since our llama_decoder test in decode mode has good PCC with program cache enabled and it contains every op present in llama_model. Is there any way to keep these two ops from having the same hash?

eyonland commented 4 months ago

I just looked back over the code for sharded_to_interleaved and interleaved_to_sharded and I don't think they will end up sharing the same hash, but it would be good to check with the visualizer and look at those hashes as a sanity check.

kevinmiTT11 commented 4 months ago

Used the ttnn visualizer to check: all program hashes look good and match those of the test_decoder tests, which have good PCC. The llama model test for 1 layer deterministically has bad PCC on the second run, even when commenting out the LM head and RMS norm.

jliangTT commented 4 months ago

Are you able to use this information to produce a repro with one op with a good program hash?

kevinmiTT11 commented 4 months ago

@jliangTT @eyonland could you elaborate on what the next steps could be for this? I'm not really sure how to repro if all hashes are the same, yet only in that test file do we see bad PCC.

kevinmiTT11 commented 4 months ago

Found a workaround: by adding del tt_inp_emb, rot_mat, attn_mask after model.forward() in test_llama_model, we were able to get good PCC consistently with program cache enabled. It seems the model is using the previous inputs, or a corrupted input, for the second iteration if we don't explicitly remove the references to the previous iteration's inputs. Calling tensor.deallocate() for the inputs on device doesn't solve the issue; only calling del (or assigning None) works. This behaviour only shows up when the program cache is enabled; when it is disabled we do not need to explicitly use del.
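
For concreteness, a minimal sketch of the workaround assuming a simple token loop; prepare_inputs and tt_model_forward are hypothetical stand-ins, and only the del line reflects the actual change in test_llama_model.py:

```python
# Stand-ins for the real input preparation and model forward pass.
def prepare_inputs():
    return "tt_inp_emb", "rot_mat", "attn_mask"   # placeholders for device tensors

def tt_model_forward(tt_inp_emb, rot_mat, attn_mask):
    return "tt_out"                               # placeholder for the forward pass

for _ in range(20):  # 20 decode tokens, as in the test above
    tt_inp_emb, rot_mat, attn_mask = prepare_inputs()
    tt_out = tt_model_forward(tt_inp_emb, rot_mat, attn_mask)
    # Workaround: explicitly drop the host-side references to this iteration's
    # inputs. tensor.deallocate() alone was not enough; only `del` (or assigning
    # None) restored good PCC with the program cache enabled.
    del tt_inp_emb, rot_mat, attn_mask
```
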

jliangTT commented 4 months ago

Just a note from talking to @uaydonat: this workaround may also apply to Falcon40B prefill. The same workaround can be tried whenever it makes sense.

johanna-rock-tt commented 4 months ago

Thanks @jliangTT. Our test from above is now also passing with good PCC. It's a bit weird though, because we already had deletes for all our model outputs before, and I didn't change anything now except for enabling the program cache and re-running the test (tip of main as of May 6th). Maybe some other issue got fixed in the meantime too? In any case, enabling the program cache works now for all tests I've tried so far.

uaydonat commented 4 months ago

@arakhmati there is some kind of data corruption (see the workaround three posts above). Is this something you can look into?

arakhmati commented 4 months ago

@uaydonat can you guys re-run your test with the following env variable:

OPERATION_HISTORY_CSV=history.csv pytest ...

Then you can open history.csv, where you will see a column for program_hash.

Then you can check whether there are unexpected hash collisions. We don't have a script for that, but one could be written.
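
For reference, one possible version of such a script; it assumes history.csv has the program_hash column mentioned above plus a column naming the operation (called operation_name here purely as an assumption), so adjust the column names to the real CSV header:

```python
import csv
from collections import defaultdict

# Group operation names by program hash; the same hash appearing for different
# operations is a suspected collision worth inspecting.
ops_by_hash = defaultdict(set)
with open("history.csv", newline="") as f:
    for row in csv.DictReader(f):
        ops_by_hash[row["program_hash"]].add(row.get("operation_name", "<unknown>"))

for program_hash, ops in ops_by_hash.items():
    if len(ops) > 1:
        print(f"possible collision: {program_hash} -> {sorted(ops)}")
```
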

Hopefully that will help you spot the issue. If not, let's ping @jliangTT and assign this to someone.

cglagovichTT commented 4 months ago

Kevin did this and saw no issues in the program hashes. He compared our passing test (decoder) with the failing test (model) and the diff in hashes looked good.

He happened to find that deleting certain outputs after each iteration resolved the issue.

mtairum commented 3 months ago

Deleting some inputs that we send to the model (namely the input and the attn mask) helped us deal with PCC issues on Mixtral with program cache enabled (issue: #8902).

yieldthought commented 3 months ago

This is also affecting Mixtral; we are applying the same workaround for now.

uaydonat commented 3 months ago

@eyonland this issue keeps popping up. Can you help us?

davorchap commented 3 months ago

> Deleting some inputs that we send to the model (namely the input and the attn mask) helped us deal with PCC issues on Mixtral with program cache enabled (issue: #8902).

There's a possibility that one of the Mixtral ops is not implementing the program cache hash function correctly?

uaydonat commented 3 months ago

Llama and Falcon40b prefill are also suffering from this.

A preliminary analysis of the hashes was reported above:

> Kevin did this and saw no issues in the program hashes. He compared our passing test (decoder) with the failing test (model) and the diff in hashes looked good.
>
> He happened to find that deleting certain outputs after each iteration resolved the issue.

tt-aho commented 3 months ago

Can we summarize which models still see these issues and which specific tests to run (and whether we need to change anything due to the workarounds that were put in)?

I see in some linked issues that there were op bugs found and fixed, so I'm not sure what the status is for all these models.

cglagovichTT commented 3 months ago

@tt-aho I can speak for Llama, which still has this issue. Repro steps are in the issue body, pasting here.

  1. Build tt-metal with the latest main on a T3000
  2. Run pytest models/demos/llama2_70b/tests/test_llama_model.py::test_LlamaModel_inference[decode-8chip-T3000-1L]
  3. Add use_program_cache to the test fixture in Line 269 of test_llama_model.py
  4. Run the command again
  5. Observe the difference in PCC

tt-aho commented 3 months ago

I have a fix for Llama and potentially other models. The address was not getting updated for the softmax mask if it was sharded. The fix is in this PR: https://github.com/tenstorrent/tt-metal/pull/9212.
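
For readers hitting similar problems, a conceptual sketch of this class of bug (not the actual tt-metal code): a cached program carries runtime args that include buffer addresses, and every address must be refreshed when the program is reused, otherwise the kernel reads whatever now lives at the old address.

```python
# Conceptual model of a cached program and its runtime-args override step.
class CachedProgram:
    def __init__(self):
        self.runtime_args = {}

def override_runtime_args(cached, input_addr, mask_addr, mask_is_sharded):
    cached.runtime_args["input"] = input_addr
    # The bug: the sharded-mask path skipped this update, so the cached program
    # kept pointing at wherever the first run's mask used to live.
    if not mask_is_sharded:   # buggy condition; the fix updates the address always
        cached.runtime_args["mask"] = mask_addr
```

This would also be consistent with the del workaround above: freeing the previous iteration's tensors presumably let the next iteration's tensors land at the addresses the cached program was still pointing to.
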

I verified that the specified Llama test now has the same PCC with and without the program cache. Please retest your models with the change from that branch (or wait until it is merged) and update this issue / file a new issue if your model still has PCC issues with program cache.

tt-aho commented 3 months ago

The fix is merged to main. Note that I did not enable the program cache for any models/tests after this fix; please verify and update accordingly. I will close this issue since the original problem should be resolved. Please file a new issue, or reopen this one, if there are still PCC bugs when enabling the program cache.

uaydonat commented 3 months ago

@johanna-rock-tt and @yieldthought can you remove the workaround and try the fix to confirm the issue is resolved?

johanna-rock-tt commented 3 months ago

Falcon40b end to end test (the one where we observed this issue) passes after rebasing to main and removing the workarounds!

mtairum commented 3 months ago

Mixtral8x7B is also passing without the explicit tensor deletes 👍