tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

ttnn.multiply/add with internal broadcasting along dim=2 is pretty slow #10623

Open kpaigwar opened 2 months ago

kpaigwar commented 2 months ago

In some models, such as Mamba, there are a few element-wise operations (add/multiply) between input tensors and weights. Typically the input tensors have a batch dimension greater than 1, while the weights have no batch dimension. The weights therefore have to be broadcast along the batch dimension before the element-wise operation.

ttnn currently supports internal broadcasting of weights, but these operations are pretty slow. I validated this by comparing against pre-broadcasted weights in the unit test below. As the table shows, multiply with pre-broadcasted weights is roughly 30x faster (4442 ns vs. 132336 ns). However, pre-broadcasting the weights is not a tractable solution, because the broadcasted copies take up extra DRAM space.
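To make the two variants concrete, here is a minimal pure-Python sketch of the math being compared: multiplying by a single weight row that is broadcast on the fly, versus multiplying by weights that were repeated along the batch dimension beforehand. It does not use ttnn or torch, and the helper names (`pre_broadcast`, `multiply_broadcast`, `multiply_elementwise`) are hypothetical; it only illustrates the semantics, not how ttnn implements either path.

```python
def pre_broadcast(weights, batch_size):
    """Materialize batch_size copies of the single weight row."""
    return [list(weights) for _ in range(batch_size)]


def multiply_broadcast(inputs, weights):
    """Multiply every input row by the same weight row (implicit broadcast)."""
    return [[x * w for x, w in zip(row, weights)] for row in inputs]


def multiply_elementwise(inputs, weights_2d):
    """Plain element-wise multiply of two equally shaped 2D lists."""
    return [
        [x * w for x, w in zip(row, w_row)]
        for row, w_row in zip(inputs, weights_2d)
    ]


# Toy sizes: batch of 2 rows, hidden size of 3.
inputs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
weights = [0.5, 2.0, -1.0]

out_bcast = multiply_broadcast(inputs, weights)
out_pre = multiply_elementwise(inputs, pre_broadcast(weights, len(inputs)))

# Both paths give the same values; only the storage of the weights differs.
assert out_bcast == out_pre
```

The results are identical, so the difference reported in this issue is purely about how long each path takes on device.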

| Op | Device Duration (ns) |
| --- | --- |
| ttnn.multiply (pre-broadcasted weights) | 4442 |
| ttnn.multiply (no pre-broadcasted weights) | 132336 |
| ttnn.bcast (no pre-broadcasted weights) | 132656 |

import ttnn
import torch
from tests.ttnn.utils_for_testing import assert_with_pcc

def test_multiply_perf(device):
    batch_size = 32
    hidden_size = 5120
    input_shape = [1, 1, batch_size, hidden_size]
    torch_input = torch.randn(input_shape)
    weights_shape = [1, 1, 1, hidden_size]
    torch_weights = torch.randn(weights_shape)

    tt_input = ttnn.from_torch(torch_input,
                             device=device, layout=ttnn.TILE_LAYOUT, 
                             memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat16)
    tt_weights = ttnn.from_torch(torch_weights,
                                device=device, layout=ttnn.TILE_LAYOUT, 
                                memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat8_b)
    tt_weights_pre_bcasted = ttnn.from_torch(torch_weights.repeat(1, 1, batch_size, 1),
                                device=device, layout=ttnn.TILE_LAYOUT, 
                                memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat8_b)

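    # Reference: [1, 1, batch, hidden] * [1, 1, 1, hidden] -> [1, 1, batch, hidden]
    torch_output = torch_input * torch_weights
    # Case 1: weights already repeated along the batch dim on the host
    mul_out1 = ttnn.multiply(tt_input, tt_weights_pre_bcasted, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat16)
    # Case 2: ttnn.multiply broadcasts the weights internally along dim=2
    mul_out2 = ttnn.multiply(tt_input, tt_weights, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat16)
    # Case 3: explicit bcast op, multiplying with broadcast along H
    mul_out3 = ttnn.experimental.operations.primary.bcast(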
                    tt_input,
                    tt_weights,
                    ttnn.experimental.tensor.BcastOpMath.MUL,
                    ttnn.experimental.tensor.BcastOpDim.H,
                    output_mem_config=ttnn.L1_MEMORY_CONFIG,
                )
    assert_with_pcc(torch_output, ttnn.to_torch(mul_out1).to(torch.float32), pcc=0.999)
    assert_with_pcc(torch_output, ttnn.to_torch(mul_out2).to(torch.float32), pcc=0.999)
    assert_with_pcc(torch_output, ttnn.to_torch(mul_out3).to(torch.float32), pcc=0.999)

Perf Sheet ttnn_multply_perf.csv
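
For reference, the slowdown factors can be recomputed directly from the three durations in the table above. This is just arithmetic on the reported numbers; the dictionary keys are labels made up for this sketch.

```python
# Device durations in nanoseconds, copied from the table in this issue.
durations_ns = {
    "multiply_pre_broadcast": 4442,
    "multiply_internal_broadcast": 132336,
    "bcast_internal_broadcast": 132656,
}

baseline = durations_ns["multiply_pre_broadcast"]

# Ratio of each op's duration to the pre-broadcast baseline.
slowdown = {name: ns / baseline for name, ns in durations_ns.items()}

for name, factor in slowdown.items():
    print(f"{name}: {factor:.1f}x the pre-broadcast duration")
```

Both internally broadcast variants come out at just under 30x the pre-broadcast duration.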

kpaigwar commented 2 months ago

fyi @esmalTT @uaydonat

uaydonat commented 2 months ago

Are different batches running on different cores?

What's the size of the weights?

Does that mean that, in the unoptimized case, every core is reading the same data from DRAM? So are there num_batches times more DRAM reads?

kpaigwar commented 2 months ago

The weights shape is [1, 1, 1, hidden_size], where hidden_size = 5120 and batch_size = 32. According to the perf sheet, the multiply kernel uses all 64 cores. I think the work is split along hidden_dim, with each core holding an activation of size [1, 1, 32, 80].
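
The per-core shape quoted above follows from an even split of hidden_size across the 64 cores. The sketch below redoes that arithmetic and also expresses it in tiles, assuming the usual 32x32 tile size for TILE_LAYOUT. How the kernel actually assigns tiles to cores is not confirmed in this thread.

```python
TILE = 32  # assumption: TILE_LAYOUT uses 32x32 tiles

batch_size = 32
hidden_size = 5120
num_cores = 64

# Even split of the hidden dimension across cores, in elements.
per_core_width = hidden_size // num_cores
per_core_shape = [1, 1, batch_size, per_core_width]

# The same split expressed in tiles.
tiles_total = (batch_size // TILE) * (hidden_size // TILE)
tiles_per_core = tiles_total / num_cores

print(per_core_shape)   # [1, 1, 32, 80]
print(tiles_total)      # 160
print(tiles_per_core)   # 2.5
```

Note that 80 columns is 2.5 tiles, so if work is handed out in whole tiles, the real per-core assignment is presumably uneven (some cores with 3 tiles, some with 2). That is an inference from the numbers, not something the perf sheet was checked for here.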

kpaigwar commented 2 months ago

I don't think reading from DRAM batch times is what makes it slow, because with pre-broadcasted weights there is the same number of DRAM reads and it is faster.
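
A back-of-the-envelope element count supports this argument. Even if the internal broadcast path re-read the weight row once per batch row, it would fetch no more weight elements than the pre-broadcasted tensor contains. The sketch counts logical elements only and ignores tile padding and bfloat8_b packing, which is a simplifying assumption.

```python
batch_size = 32
hidden_size = 5120

# Logical weight elements for shape [1, 1, 1, hidden_size].
weight_elems = 1 * hidden_size

# Elements in the pre-broadcasted weights, shape [1, 1, batch_size, hidden_size].
pre_broadcast_elems = batch_size * hidden_size

# Worst case for internal broadcast: the weight row is re-read per batch row.
worst_case_bcast_reads = batch_size * weight_elems

print(pre_broadcast_elems)     # 163840
print(worst_case_bcast_reads)  # 163840
```

Since the faster path already reads this many weight elements, read volume alone is unlikely to explain the roughly 30x gap.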