tenstorrent / tt-metal

:metal: TT-NN operator library and TT-Metalium low-level kernel programming model.
Apache License 2.0

Large DRAM-sharded matmuls reliably hang on wormhole #13668

Open yieldthought opened 1 month ago

yieldthought commented 1 month ago

Describe the bug
Large DRAM-sharded matmuls cause Wormhole to hang after a few iterations. Reducing the number of columns in the weight matrix to below 32k seems to work around the issue. Sometimes the board seems to need to be reset twice before it can be used again.

To Reproduce
Steps to reproduce the behavior:

  1. Build metal
  2. pytest dram_sharded_hang.py

Expected behavior
The progress bar reaches 100% and pytest passes.

Screenshots

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | MMIO Device 0 : Tunnel 0 : Device 0
                  Metal | INFO     | MMIO Device 0 : Tunnel 0 : Device 4
                  Metal | INFO     | Enabling program cache on device 0
  5%|█▌                          | 5/100 [00:01<00:17,  5.58it/s]


Additional context
This doesn't happen on every iteration. Jasmina suggested we file it as a suspected di/dt issue, but it could also be a race in DRAM reading.

yieldthought commented 1 month ago

Not sure why GitHub won't let me attach a Python file, so here is dram_sharded_hang.py:

import math
import torch
import ttnn
import pytest
from tqdm import tqdm

@pytest.mark.timeout(300)
def test_dram_sharded_matmul(device, use_program_cache, reset_seeds):
    # Model configuration
    dim = 4096
    vocab_size = 128256
    split_size = vocab_size // 2

    # Create dummy input
    batch_size = 1
    seq_len = 1
    x = torch.randn(batch_size, seq_len, dim)

    # Create dummy weights and split them
    output_weight = torch.randn(vocab_size, dim)
    output_weight_1 = output_weight[:split_size]
    output_weight_2 = output_weight[split_size:]

    # Perform PyTorch matmul for comparison
    reference_output = torch.matmul(x, output_weight.t())

    # Configure memory layout for output_weight (for both parts)
    def create_output_mem_config(size):
        # Calculate padded size to ensure it's divisible by (32 * 12)
        padded_size = math.ceil(size / (32 * 12)) * (32 * 12)
        if padded_size != size:
            print(f"Original size: {size}, Padded size: {padded_size}")
        shard_spec = ttnn.ShardSpec(
            ttnn.CoreRangeSet({
                ttnn.CoreRange(
                    ttnn.CoreCoord(0, 0),
                    ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1)
                )
            }),
            (4096, padded_size // 12),
            ttnn.ShardOrientation.ROW_MAJOR,
            False
        )
        return ttnn.MemoryConfig(
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
            shard_spec
        )

    # Convert output_weight parts to ttnn tensors
    output_weight_ttnn_1 = ttnn.as_tensor(
        output_weight_1.permute(1, 0),
        device=device,
        memory_config=create_output_mem_config(split_size),
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat8_b
    )
    output_weight_ttnn_2 = ttnn.as_tensor(
        output_weight_2.permute(1, 0),
        device=device,
        memory_config=create_output_mem_config(split_size),
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat8_b
    )

    # Convert input to ttnn tensor
    x_ttnn = ttnn.from_torch(
        x,
        device=device,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        memory_config=ttnn.create_sharded_memory_config(
                (32, 4096 // 64),  # Shard shape: [32, 64] -> 1 shard per core
                ttnn.CoreGrid(y=8, x=8),
                ttnn.ShardStrategy.WIDTH,
                ttnn.ShardOrientation.ROW_MAJOR,
                use_height_and_width_as_shard_shape=True,
            ),
    )

    # Configure compute kernel
    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi2,
        math_approx_mode=False,
        fp32_dest_acc_en=False,
        packer_l1_acc=True,
    )

    # Configure program
    program_config = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
        in0_block_w=1,
        per_core_M=1,
        per_core_N=32, # 128256 / 2 / tile_size / core_count
        fused_activation=None,
    )

    for i in tqdm(range(100)):
        # Run the linear layers
        output_1 = ttnn.linear(
            x_ttnn,
            output_weight_ttnn_1,
            compute_kernel_config=compute_kernel_config,
            program_config=program_config,
            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
            dtype=ttnn.bfloat8_b,
        )
        output_2 = ttnn.linear(
            x_ttnn,
            output_weight_ttnn_2,
            compute_kernel_config=compute_kernel_config,
            program_config=program_config,
            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
            dtype=ttnn.bfloat8_b,
        )

        output_1 = ttnn.sharded_to_interleaved(output_1)
        output_2 = ttnn.sharded_to_interleaved(output_2)

        # Concatenate the outputs
        output = ttnn.concat([output_1, output_2], dim=-1)

        # Convert output back to PyTorch tensor
        ttnn_output_torch = ttnn.to_torch(output)

        # Assertions
        assert ttnn_output_torch.shape == (batch_size, seq_len, vocab_size), f"Expected output shape {(batch_size, seq_len, vocab_size)}, but got {ttnn_output_torch.shape}"
        assert not torch.isnan(ttnn_output_torch).any(), "Output contains NaN values"
        assert not torch.isinf(ttnn_output_torch).any(), "Output contains infinite values"

    print("Output shape:", ttnn_output_torch.shape)
    print("TTNN output sample:", ttnn_output_torch[0, 0, :10].tolist())  # Print first 10 elements of the TTNN output
    print("Reference output sample:", reference_output[0, 0, :10].tolist())  # Print first 10 elements of the reference output

    # Compare TTNN output with PyTorch matmul
    pcc = ttnn.pearson_correlation_coefficient(ttnn_output_torch.flatten(), reference_output.flatten())

yieldthought commented 1 month ago

For context: this is Llama 3's LM head matmul. We have to split it into two because otherwise there's not enough L1 on the 12 DRAM-sharded cores to handle columns this large.
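
To make the sizes concrete, here is a small back-of-envelope sketch (mine, not from the issue) of the per-bank shard shapes that create_output_mem_config in the repro script produces, assuming the 12 DRAM banks used in that script:

import math

dim = 4096
vocab_size = 128256
split_size = vocab_size // 2  # 64128 columns per half

def per_bank_shard(cols, num_banks=12, tile=32):
    # Mirrors the padding in create_output_mem_config: round the column count
    # up to a multiple of (tile * num_banks), then split it evenly across banks.
    padded = math.ceil(cols / (tile * num_banks)) * (tile * num_banks)
    return (dim, padded // num_banks)

print(per_bank_shard(vocab_size))  # (4096, 10688) -> the full LM head in one matmul
print(per_bank_shard(split_size))  # (4096, 5344)  -> each half after the 2-way split

Note that each half still has 64128 columns in total, well above the ~32k threshold mentioned in the bug description, which is consistent with the 2-way split above still hanging.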

pavlepopovic commented 3 weeks ago

Tested this on 7 machines in total (a mix of N300 and N150); here are the results:

Removing the matmul_block() call from the compute kernel did NOT make the issue go away, indicating that this is not a di/dt issue that we know of (cores starting matmul computation at the same time).

Removing the semaphore calls from the in0_sender kernel did not make the issue go away.

Removing the noc_async_read_tile_dram_sharded_with_state_with_trid() calls from the in1_sender kernel did make the issue go away, on all machines. So I believe this is either a problem with that function or an underlying NoC/DRAM issue. FYI @ttmtrajkovic @yugaoTT @davorchap

cglagovichTT commented 3 weeks ago

Sounds similar to this DRAM-sharded matmul PCC issue: https://github.com/tenstorrent/tt-metal/issues/10673. See the latest comments by Alex regarding the potential fixes he is trying.

yugaoTT commented 3 weeks ago

Could you try adding a delay right after noc_async_read_tile_dram_sharded_with_state_with_trid? That would slow down the DRAM read frequency, so we can see whether reading DRAM too fast is what causes the hangs.

ttmtrajkovic commented 2 weeks ago

@yugaoTT, @davorchap

Based on the findings from @pavlepopovic, this problem shouldn't be related to di/dt; his debugging indicated a problem with one of the DM (data movement) functions. The di/dt label should be removed from this bug, and it should be prioritized for debug, as it shows inconsistent behaviour across chips and may impact customer demos if this code is being tested by the customer team (@milank94).

@yugaoTT, further debug needs to be owned by you or by whoever added this API (if it's not you). Please reassign if you disagree.

Milos

uaydonat commented 1 week ago

@yieldthought what is the workaround for this issue?

yieldthought commented 1 week ago

Reducing the number of columns in the weight matrix to below 32k seems to work around the issue, so we break the matmul up into N smaller matmuls, execute them one at a time, and then concatenate the output tensors. It's a performance hit, but we only pay it once per token rather than once per layer, as only our LM head is this large.
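
For illustration, here is a minimal sketch of that N-way split, reusing the setup from dram_sharded_hang.py above (output_weight, x_ttnn, create_output_mem_config, compute_kernel_config and its imports). The chunk count, the reading of "32k" as 32768 columns, and the host-side trim/concat are my assumptions, not details from this thread:

# Hedged sketch of the workaround: keep each weight chunk under ~32k columns,
# run one DRAM-sharded matmul per chunk, then concatenate the outputs.
max_cols = 32 * 1024                           # assumed meaning of "below 32k"
num_splits = math.ceil(vocab_size / max_cols)  # 4 chunks for vocab_size = 128256
split_cols = vocab_size // num_splits          # 32064; assumes an even split

# per_core_N mirrors the "128256 / 2 / tile_size / core_count" formula above,
# recomputed for the narrower chunk width.
chunk_program_config = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
    in0_block_w=1,
    per_core_M=1,
    per_core_N=math.ceil(split_cols / 32 / 64),
    fused_activation=None,
)

outputs = []
for i in range(num_splits):
    w_chunk = output_weight[i * split_cols:(i + 1) * split_cols]
    w_ttnn = ttnn.as_tensor(
        w_chunk.permute(1, 0),
        device=device,
        memory_config=create_output_mem_config(split_cols),
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat8_b,
    )
    out = ttnn.linear(
        x_ttnn,
        w_ttnn,
        compute_kernel_config=compute_kernel_config,
        program_config=chunk_program_config,
        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
    )
    out = ttnn.sharded_to_interleaved(out)
    # Trim any columns added by tile/bank padding before concatenating on host.
    outputs.append(ttnn.to_torch(out)[..., :split_cols])

logits = torch.cat(outputs, dim=-1)  # (1, 1, vocab_size); paid once per token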

yugaoTT commented 1 week ago

Tested a slightly different case: shape (32, 4096, 64128) on 64 cores. With the original code (which allows up to 128 read requests in flight per transaction id), it hangs at the very end of the 100 iterations. Slowing down the DRAM read loop (adding empty loops between read requests) makes it pass for 100 iterations. Allowing fewer read requests in flight per transaction id (up to 64) also makes it pass. Not using transaction ids at all also makes it pass.

Could be a timing issue related to transaction ids.

yugaoTT commented 1 week ago

Correction: after adding a device synchronize, the original code hangs at iteration 17.

yugaoTT commented 1 week ago

Also tested the code @yieldthought wrote above: on main it hangs at 10%. Changing line 344 in tt_metal/hw/inc/wormhole/noc_nonblocking_api.h to this (64 read requests per transaction id) makes it pass for 100 iterations: while (NOC_STATUS_READ_REG(noc, NIU_MST_REQS_OUTSTANDING_ID(trid)) > ((NOC_MAX_TRANSACTION_ID_COUNT+1)/4));

davorchap commented 1 week ago

Also tested the code @yieldthought wrote above: on main it hangs at 10%. Changing line 344 in tt_metal/hw/inc/wormhole/noc_nonblocking_api.h to this (64 read requests per transaction id) makes it pass for 100 iterations: while (NOC_STATUS_READ_REG(noc, NIU_MST_REQS_OUTSTANDING_ID(trid)) > ((NOC_MAX_TRANSACTION_ID_COUNT+1)/4));

Do we overflow the NOC's outstanding-request counter? Does the counter only support up to 64 outstanding transfers per transaction id?

yugaoTT commented 1 week ago

We have this overflow check, but it was previously set to 128 (which ideally should be fine, since the registers are 8 bits for each transaction id).
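
For reference, the arithmetic implied by the comments above (my reading, not stated explicitly in the thread): an 8-bit per-transaction-id counter saturates at 255 outstanding requests, and the wait-loop threshold quoted earlier drops from half of that to a quarter:

# 8-bit outstanding-request counter per transaction id saturates at 255.
counter_max = 2**8 - 1                   # 255
old_threshold = (counter_max + 1) // 2   # 128 reads in flight per trid (previous setting)
new_threshold = (counter_max + 1) // 4   # 64, the value that made the test pass
print(old_threshold, new_threshold)      # 128 64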

yieldthought commented 4 days ago

Why was this (NOC_MAX_TRANSACTION_ID_COUNT+1)/2 in the first place? Why not wait until the count is below the max? Performance reasons?

yugaoTT commented 4 days ago

The number came from the BUDA backend; it's a heuristic value and I carried it over to Metal.