yieldthought opened 1 month ago
Not sure why GitHub won't allow me to attach a Python file, so here is dram_sharded_hang.py:
import math

import torch
import ttnn
import pytest
from tqdm import tqdm


@pytest.mark.timeout(300)
def test_dram_sharded_matmul(device, use_program_cache, reset_seeds):
    # Model configuration
    dim = 4096
    vocab_size = 128256
    split_size = vocab_size // 2

    # Create dummy input
    batch_size = 1
    seq_len = 1
    x = torch.randn(batch_size, seq_len, dim)

    # Create dummy weights and split them
    output_weight = torch.randn(vocab_size, dim)
    output_weight_1 = output_weight[:split_size]
    output_weight_2 = output_weight[split_size:]

    # Perform PyTorch matmul for comparison
    reference_output = torch.matmul(x, output_weight.t())

    # Configure memory layout for output_weight (for both parts)
    def create_output_mem_config(size):
        # Calculate padded size to ensure it's divisible by (32 * 12)
        padded_size = math.ceil(size / (32 * 12)) * (32 * 12)
        if padded_size != size:
            print(f"Original size: {size}, Padded size: {padded_size}")
        shard_spec = ttnn.ShardSpec(
            ttnn.CoreRangeSet(
                {
                    ttnn.CoreRange(
                        ttnn.CoreCoord(0, 0),
                        ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1),
                    )
                }
            ),
            (4096, padded_size // 12),
            ttnn.ShardOrientation.ROW_MAJOR,
            False,
        )
        return ttnn.MemoryConfig(
            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            ttnn.BufferType.DRAM,
            shard_spec,
        )

    # Convert output_weight parts to ttnn tensors
    output_weight_ttnn_1 = ttnn.as_tensor(
        output_weight_1.permute(1, 0),
        device=device,
        memory_config=create_output_mem_config(split_size),
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat8_b,
    )
    output_weight_ttnn_2 = ttnn.as_tensor(
        output_weight_2.permute(1, 0),
        device=device,
        memory_config=create_output_mem_config(split_size),
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat8_b,
    )

    # Convert input to ttnn tensor
    x_ttnn = ttnn.from_torch(
        x,
        device=device,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        memory_config=ttnn.create_sharded_memory_config(
            (32, 4096 // 64),  # Shard shape: [32, 64] -> 1 shard per core
            ttnn.CoreGrid(y=8, x=8),
            ttnn.ShardStrategy.WIDTH,
            ttnn.ShardOrientation.ROW_MAJOR,
            use_height_and_width_as_shard_shape=True,
        ),
    )

    # Configure compute kernel
    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi2,
        math_approx_mode=False,
        fp32_dest_acc_en=False,
        packer_l1_acc=True,
    )

    # Configure program
    program_config = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
        in0_block_w=1,
        per_core_M=1,
        per_core_N=32,  # 128256 / 2 / tile_size / core_count
        fused_activation=None,
    )

    for i in tqdm(range(100)):
        # Run the linear layers
        output_1 = ttnn.linear(
            x_ttnn,
            output_weight_ttnn_1,
            compute_kernel_config=compute_kernel_config,
            program_config=program_config,
            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
            dtype=ttnn.bfloat8_b,
        )
        output_2 = ttnn.linear(
            x_ttnn,
            output_weight_ttnn_2,
            compute_kernel_config=compute_kernel_config,
            program_config=program_config,
            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
            dtype=ttnn.bfloat8_b,
        )
        output_1 = ttnn.sharded_to_interleaved(output_1)
        output_2 = ttnn.sharded_to_interleaved(output_2)

        # Concatenate the outputs
        output = ttnn.concat([output_1, output_2], dim=-1)

    # Convert output back to PyTorch tensor
    ttnn_output_torch = ttnn.to_torch(output)

    # Assertions
    assert ttnn_output_torch.shape == (batch_size, seq_len, vocab_size), f"Expected output shape {(batch_size, seq_len, vocab_size)}, but got {ttnn_output_torch.shape}"
    assert not torch.isnan(ttnn_output_torch).any(), "Output contains NaN values"
    assert not torch.isinf(ttnn_output_torch).any(), "Output contains infinite values"

    print("Output shape:", ttnn_output_torch.shape)
    print("TTNN output sample:", ttnn_output_torch[0, 0, :10].tolist())  # First 10 elements of the TTNN output
    print("Reference output sample:", reference_output[0, 0, :10].tolist())  # First 10 elements of the reference output

    # Compare TTNN output with PyTorch matmul
    pcc = ttnn.pearson_correlation_coefficient(ttnn_output_torch.flatten(), reference_output.flatten())
For context: this is Llama 3's LM head matmul. We have to split it into two because otherwise there's not enough L1 on the 12 DRAM-sharded cores to handle columns this large.
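For reference, here is the arithmetic behind the shard shapes and per_core_N used in the test above (a sketch only; matmul_cores is just an illustrative name for the 8x8 grid the program config targets, nothing here is measured on hardware):

import math

dim = 4096
vocab_size = 128256
tile = 32
num_dram_banks = 12                                    # banks targeted by the DRAM width shard
matmul_cores = 64                                      # 8x8 grid in the program config

split_size = vocab_size // 2                           # 64128 columns per half
padded = math.ceil(split_size / (tile * num_dram_banks)) * (tile * num_dram_banks)  # 64128, already aligned
shard_shape = (dim, padded // num_dram_banks)          # (4096, 5344) per DRAM bank

per_core_N = math.ceil(split_size / tile / matmul_cores)   # ceil(64128 / 32 / 64) = 32 output tiles per core
print(shard_shape, per_core_N)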
Tested this one on 7 machines in total (mix of N300 and N150), here are the results:
- Removing the matmul_block() call from the compute kernel did NOT make the issue go away, indicating this is not the di/dt issue we know of (cores starting matmul computation at the same time).
- Removing the semaphore calls from the in0_sender kernel did not make the issue go away.
- Removing the noc_async_read_tile_dram_sharded_with_state_with_trid() calls from the in1_sender kernel DID make the issue go away, on all machines.
So I believe this is either a problem with that function or an underlying NoC/DRAM issue.
FYI @ttmtrajkovic @yugaoTT @davorchap
Sounds similar to this DRAM-sharded matmul PCC issue: https://github.com/tenstorrent/tt-metal/issues/10673. See the latest comments by Alex regarding the potential fixes he is trying.
Could you try adding a delay right after noc_async_read_tile_dram_sharded_with_state_with_trid()? That would slow down the DRAM read frequency, so we can see whether reading DRAM too fast is what causes the hang.
@yugaoTT, @davorchap
Based on the findings from @pavlepopovic, this problem shouldn't be related to di/dt; his debug indicated a problem with one of the DM (data movement) functions. The bug should have the di/dt label removed and should be prioritized for debug, as it shows inconsistent behaviour across chips and may impact customer demos if this code is being exercised by the customer team (@milank94).
@yugaoTT, further debug needs to be owned by you or by whoever added this API (if it's not you). Please reassign if you disagree.
Milos
@yieldthought what is the workaround for this issue?
Reducing the number of columns in the weight matrix to below 32k seems to work around the issue, so we break the matmul up into N smaller matmuls, execute them one at a time, and then concatenate the output tensors. It's a performance hit, but we only pay it once per token rather than once per layer, since only our LM head is this large.
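A minimal sketch of that workaround, reusing the tensors from the repro above but with plain DRAM-interleaved weights for brevity (the real code builds a DRAM-sharded memory config and matching program config per chunk, as in the test); num_chunks and chunk_cols are illustrative values, not the exact split we ship:

# Hypothetical N-way split of the LM head: keep each chunk below the ~32k-column
# threshold that appears to avoid the hang, then concatenate the partial outputs.
num_chunks = 8
chunk_cols = vocab_size // num_chunks        # 128256 / 8 = 16032 columns per chunk (tile-aligned)

outputs = []
for i in range(num_chunks):
    w_chunk = output_weight[i * chunk_cols : (i + 1) * chunk_cols]  # (chunk_cols, dim) slice
    w_ttnn = ttnn.as_tensor(
        w_chunk.permute(1, 0),               # transpose to (dim, chunk_cols) for ttnn.linear
        device=device,
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat8_b,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    outputs.append(ttnn.linear(x_ttnn, w_ttnn, compute_kernel_config=compute_kernel_config))

logits = ttnn.concat(outputs, dim=-1)        # (batch, seq, vocab_size)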
Tested a slightly different shape, (32, 4096, 64128), on 64 cores:
- Original code (allows up to 128 read requests in flight per transaction ID): hangs at the very end of 100 iterations.
- Slowing down the DRAM read loop (adding empty loops between read requests): passes 100 iterations.
- Allowing fewer read requests in flight per transaction ID (up to 64): passes.
- Not using transaction IDs at all: passes.
Could be a timing issue related to transaction IDs.
Correction: after adding a device synchronize, the original code hangs at iteration 17.
Also tested the code @yieldthought posted above; it hangs at 10% on main. Changing line 344 in tt_metal/hw/inc/wormhole/noc_nonblocking_api.h to this (64 read requests per transaction ID):
while (NOC_STATUS_READ_REG(noc, NIU_MST_REQS_OUTSTANDING_ID(trid)) > ((NOC_MAX_TRANSACTION_ID_COUNT+1)/4));
makes it pass for 100 iterations.
Do we overflow the NOC's outstanding-request counter? Does the counter only support up to 64 outstanding transfers per transaction ID?
We have this overflow check, but it was previously set to 128 (which should ideally be fine, since the counters are 8-bit for each transaction ID).
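For concreteness, the two thresholds being compared, assuming the counter really is 8-bit per transaction ID (i.e. NOC_MAX_TRANSACTION_ID_COUNT + 1 == 256; that value is inferred from the comments in this thread, not checked against the header):

NOC_MAX_TRANSACTION_ID_COUNT = 255  # assumed: 8-bit outstanding-request counter per transaction ID
old_threshold = (NOC_MAX_TRANSACTION_ID_COUNT + 1) // 2  # 128 reads in flight -> hangs
new_threshold = (NOC_MAX_TRANSACTION_ID_COUNT + 1) // 4  # 64 reads in flight -> passes 100 iters
print(old_threshold, new_threshold)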
Why was this (NOC_MAX_TRANSACTION_ID_COUNT+1)/2 in the first place? Why not wait until less than the max? Performance reasons?
The number comes from the BUDA backend; it's a heuristic value that I carried over to Metal.
Describe the bug
Large DRAM-sharded matmuls cause Wormhole to hang after a few iterations. Reducing the number of columns in the weight matrix to below 32k seems to work around the issue. Sometimes the board seems to need to be reset twice before it can be used again.
To Reproduce
Steps to reproduce the behavior:
pytest dram_sharded_hang.py
Expected behavior
The progress bar reaches 100% and pytest passes.
Additional context
Doesn't happen on every iteration. Jasmina suggested we file this as a suspected di/dt issue, but it could also be a race in DRAM reading.