tenstorrent / tt-metal

:metal: TT-NN operator library and TT-Metalium low-level kernel programming model.
Apache License 2.0

TopK op is causing some multidevice bad state #9206

Closed mtairum closed 3 months ago

mtairum commented 4 months ago

Issue Description

This behaviour was first spotted in T3000 CI tests. The issue persists after device close and seemingly can only be fixed with a device reset.

I managed to triage the issue to the ttnn.experimental.operations.primary.topk.

When a module with the TopK op is executed on multi-device before Falcon40B decode test, the output PCC of falcon is greatly lowered (around 0.8 instead of 0.99).

This has been observed in multiple SJC T3000 machines.

Reproduction steps

  1. Run a small unit test of the topK op. I have tested both the shapes used in our Mixtral MoE module (below) as well as the tests inside tests/tt_eager/python_api_testing/unit_testing/misc/test_topk.py:

test_topk.py

import torch
import ttnn

def test_topK(t3k_device_mesh, use_program_cache, reset_seeds):
    gate_shape = [1,1,32,64]
    gate_torch = torch.randn(gate_shape)
    gate_logits = ttnn.from_torch(
        gate_torch,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        device=t3k_device_mesh,
        mesh_mapper=ttnn.ReplicateTensorToMesh(t3k_device_mesh),
    )
    # Simply executing topk is enough to trigger the bad state; no output check is needed here.
    ttnn.experimental.operations.primary.topk(gate_logits, 32)
  2. Run the falcon40b decode test (it will download the Falcon weights if they are not already on your machine): pytest models/demos/t3000/falcon40b/tests/test_falcon_decoder.py::test_FalconDecoder_inference[BFLOAT8_B-SHARDED-falcon_40b-layer_0-decode_batch32-8chips]

The output of falcon40b test will be something like:

Output: Max ATOL Delta: 58.7618408203125, Max RTOL Delta: inf, PCC: 0.8449884518114068, PCC check failed
K Cache: Max ATOL Delta: 0.22574996948242188, Max RTOL Delta: inf, PCC: 0.9999446038705259
V Cache: Max ATOL Delta: 0.019847989082336426, Max RTOL Delta: inf, PCC: 0.9999691527555423
K Cache new token: Max ATOL Delta: 0.22574996948242188, Max RTOL Delta: inf, PCC: 0.9998440004560603
V Cache new token: Max ATOL Delta: 0.019847989082336426, Max RTOL Delta: inf, PCC: 0.9997933878705257

The correct output PCC should be instead PCC: 0.9984268799778373
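For reference, the PCC figures above are Pearson correlation coefficients between the golden (torch) output and the device output, flattened. A minimal pure-Python sketch of the metric (an assumption about how it is computed here, not the repo's actual implementation):

```python
import math

def pcc(a, b):
    """Pearson correlation coefficient of two equal-length flat sequences."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)

# Identical tensors give PCC 1.0; corruption from the bad device state
# lowers it toward the ~0.84 seen in the failing run above.
ref = [0.1, 0.5, -0.2, 0.9]
print(pcc(ref, ref))
```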

mtairum commented 4 months ago

@sjameelTT You worked on the topK op, right? Would you be the right person to add to this ticket?

mtairum commented 4 months ago

Updated the description.

This can be replicated by running any of the TopK tests inside tests/tt_eager/python_api_testing/unit_testing/misc/test_topk.py and the falcon40B decode.

sjameelTT commented 4 months ago

Yeah I'm the right person for this. I'll take a look.

sjameelTT commented 4 months ago

Curious as to what you're getting when rebased onto tip. I tried rerunning the tests on CI with

pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_topk.py::test_topk[1-1-32-64-32-BFLOAT8_B]

followed by:

pytest models/demos/t3000/falcon40b/tests/test_falcon_decoder.py::test_FalconDecoder_inference[BFLOAT8_B-SHARDED-falcon_40b-layer_0-decode_batch32-8chips]

and got 0.984 PCC, which is better than the 0.845 in the first post, though it still causes a failure.

A big topk change was merged in on Wednesday; perhaps that caused the improvement. I'm waiting on getting access to IRD, so I can't get to my own T3000 atm to test, and I've just been using CI for now.

sjameelTT commented 3 months ago

Talked to @rdjogoTT and it seems that topk_tile_init() puts the SFPU into a special mode needed to copy the swaps performed during sorting over to the indices. The fix is to either make every other op init back to the normal mode, or init back to the normal mode at the end of topk.

sjameelTT commented 3 months ago

https://github.com/tenstorrent/tt-metal/actions/runs/9454419851/job/26042092028

This pipeline works and runs a temporary fix that moves the Mixtral tests back before Falcon.

uaydonat commented 3 months ago

Talked to @rdjogoTT and it seems that topk_tile_init() puts the SFPU into a special mode needed to copy the swaps performed during sorting over to the indices. The fix is to either make every other op init back to the normal mode, or init back to the normal mode at the end of topk.

The fix should be to put config back to default state, other ops should not know about what has run before them
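To illustrate the principle in toy Python (not real tt-metal APIs; the names topk_tile, some_other_op, and sfpu_mode are hypothetical): an op that flips a global mode register restores the default before returning, so ops that run later need no knowledge of what ran before them.

```python
# Toy sketch of "restore config to default at the end of the op".
DEFAULT_MODE = "normal"
sfpu_mode = DEFAULT_MODE

def topk_tile(values):
    """Sorts descending, using a special mode register while it runs."""
    global sfpu_mode
    sfpu_mode = "index-swap"      # special mode needed during sorting
    result = sorted(values, reverse=True)
    sfpu_mode = DEFAULT_MODE      # restore default so later ops are unaffected
    return result

def some_other_op(x):
    """Relies on the default mode; would silently misbehave in 'index-swap'."""
    assert sfpu_mode == DEFAULT_MODE
    return [v * 2 for v in x]

print(topk_tile([3, 1, 2]))   # [3, 2, 1]
print(some_other_op([1, 2]))  # [2, 4]
```

Without the restore line, the assertion in some_other_op would fail after any topk call, which is the toy analogue of the PCC drop described in this issue.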

mtairum commented 3 months ago

I just tried the commit from @rdjogoTT (d5312d7c1f997c59cbaeb12facb65c3866c026c1) in isolation (cherry-picked onto latest main) and the PCC is good as well: 0.998.

Would there be a perf impact in choosing one of the solutions over the other? I.e. making other ops init to the normal mode vs. initing back to the normal mode at the end of topk.

rdjogoTT commented 3 months ago

I just tried the commit from @rdjogoTT (d5312d7c1f997c59cbaeb12facb65c3866c026c1) in isolation (cherry-picked onto latest main) and the PCC is good as well: 0.998.

Would there be a perf impact in choosing one of the solutions over the other? I.e. making other ops init to the normal mode vs. initing back to the normal mode at the end of topk.

Perf impact should be small either way, but technically initializing back to normal mode after topk would add fewer new init calls. However, my commit implements the other solution because that is what we have in Buda; it just wasn't captured by the submodule migration, which is why it was missing from metal.

mtairum commented 3 months ago

Are there many ops that require a special sfpu config?

I understand that the solution as implemented is the right way of doing it, since it avoids potential issues with forgetting to set the config back when introducing new ops. However, this adds a small init-call overhead to all ops, right?

We can measure e2e and device times with this change to understand how much impact it has on large models with a big number of ops, and whether it's noticeable.

I'll close this issue when the PR #9358 with your change gets merged into main 👍

sjameelTT commented 3 months ago

Is there an estimate for the cycle count when adding it to each init call?

rdjogoTT commented 3 months ago

The init for the SFPU config reg calls 3 single-cycle instructions, so 3 cycles. Would it be possible to measure e2e and device times to get some numbers to look at?

sjameelTT commented 3 months ago

Isn't the clock speed 1 GHz? Wouldn't 3 cycles be extremely negligible even if it's across all inits?

rdjogoTT commented 3 months ago

Yeah I think it should be very small, but I don't know how much it can add up in a model with many ops.
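A back-of-envelope estimate of how much it can add up, assuming the 1 GHz clock mentioned above and a hypothetical (illustrative, not measured) count of 10,000 op inits per model iteration:

```python
# Worst-case sketch: 3 extra single-cycle instructions per op init.
clock_hz = 1e9          # 1 GHz core clock (assumption from the discussion above)
cycles_per_init = 3     # per rdjogoTT's estimate
inits = 10_000          # hypothetical number of op inits in a large model

overhead_s = cycles_per_init * inits / clock_hz
print(f"{overhead_s * 1e6:.1f} us total")  # 30.0 us total
```

Even at this scale the total is tens of microseconds, orders of magnitude below the ~145-168 ms e2e times reported later in the thread, which is consistent with the perf measurements showing no impact.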

sjameelTT commented 3 months ago

https://github.com/tenstorrent/tt-metal/actions/runs/9467796636

CI run with Radomir's change and with Mixtral moved back before Falcon.

mtairum commented 3 months ago

I ran some perf tests of a full Mixtral iteration with 32L and 2048 KV len, and this change doesn't affect timings.

Old is current main; new is with the added SFPU config inits.

Device only: (perf screenshot omitted)

E2E: (perf screenshot omitted)

Regarding e2e, I've seen some variability (between 145 ms and 168 ms) with either version, so any impact this change might have is negligible.

I think Radomir's change is good as is, especially since it helps avoid future issues when introducing new ops.

rdjogoTT commented 3 months ago

Great, in that case I'll merge my PR into main as soon as it passes post-commit.

rdjogoTT commented 3 months ago

https://github.com/tenstorrent/tt-metal/pull/9358 merged