tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

softmax fails with padded tile layout #4599

Open eyonland opened 9 months ago

eyonland commented 9 months ago

The softmax in tt-lib will fail if the tensor is in tile layout with a shape of (8, 2, 2) padded to (8, 32, 32) using a padding value of zero instead of -inf.

We have a workaround: at the moment we fall back to moreh softmax in core.py with:

is_padded_and_using_tile = (
    input_tensor.layout == TILE_LAYOUT
    and list(input_tensor.shape)[-2:] != list(input_tensor.shape.padded())[-2:]
)
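
For context, here is a minimal host-side torch sketch (independent of tt-lib) of why a padding value of zero corrupts the softmax while -inf does not: exp(0) = 1 inflates the denominator, whereas exp(-inf) = 0 drops the padded lanes.

import torch
import torch.nn.functional as F

# One row with true width 2, padded out to the tile width of 32.
row = torch.tensor([1.0, 2.0])
padded_zero = F.pad(row, (0, 30), value=0.0)
padded_neg_inf = F.pad(row, (0, 30), value=float("-inf"))

ref = F.softmax(row, dim=-1)                          # softmax over the true width
with_zero = F.softmax(padded_zero, dim=-1)[:2]        # exp(0) = 1 pollutes the denominator
with_neg_inf = F.softmax(padded_neg_inf, dim=-1)[:2]  # exp(-inf) = 0, padding is ignored

print(torch.allclose(ref, with_neg_inf))  # True
print(torch.allclose(ref, with_zero))     # False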

Here is a sample test from ttnn

import torch
import torch.nn.functional as F

import ttnn

# Test helpers from the tt-metal test tree (exact import paths may vary by version):
from models.utility_functions import skip_for_wormhole_b0, torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc


@skip_for_wormhole_b0()
def test_softmax_with_padded_tile_layout(device):
    torch.manual_seed(0)
    torch_input_tensor = torch_random((8, 2, 2), -10, 10, dtype=torch.bfloat16)
    torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input_tensor)
    input_tensor = ttnn.to_layout(input_tensor, ttnn.TILE_LAYOUT)
    input_tensor = ttnn.to_device(input_tensor, device)
    output_tensor = ttnn.softmax(input_tensor, dim=-1)
    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
    output_tensor = ttnn.from_device(output_tensor)
    output_tensor = ttnn.to_torch(output_tensor)

    assert_with_pcc(torch_output_tensor, output_tensor, 0.997)
tt-aho commented 8 months ago

The issue here is that most ttl ops do not understand/check for padded shapes. They interpret the passed-in data as if that were the true shape, since a lot of these ops were written before padded shapes were supported/widespread. So if you pass in [8, 2, 2] in row-major, the ttl softmax would pad internally and run, but when you pass in a padded [8, 32, 32], it interprets the 32 as the actual width.

We haven't gone back to update most ops to understand padded/true shapes after this was added, so this issue will mostly pop up on ops that do normalization/computations along dims, like layernorm, softmax, etc.
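
To illustrate the logical vs. padded shape distinction, a quick host-side sketch (assuming recent ttnn; the shape accessors mirror the workaround above, and output formatting may differ by version):

import torch
import ttnn

# Logical shape [8, 2, 2]; tile layout pads the last two dims up to the 32x32 tile.
t = ttnn.from_torch(torch.rand(8, 2, 2, dtype=torch.bfloat16))
t = ttnn.to_layout(t, ttnn.TILE_LAYOUT)

print(t.shape)           # logical shape, [8, 2, 2]
print(t.shape.padded())  # tile-padded shape, [8, 32, 32] -- what many ttl ops treat as the true shape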

yugaoTT commented 6 months ago

How should we add support for padded shapes? I guess we need to get the unpadded shape and calculate the scalers on host.
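
For the record, one host-side framing of that idea, as an illustrative torch sketch rather than the ttl kernel: recover the true width, mask the padded tail with -inf, and only then normalize.

import torch
import torch.nn.functional as F

def softmax_over_true_width(padded: torch.Tensor, true_width: int) -> torch.Tensor:
    # Mask the padded tail of the last dim with -inf so it contributes
    # exp(-inf) = 0 to the denominator, then softmax over the full padded width.
    masked = padded.clone()
    masked[..., true_width:] = float("-inf")
    return F.softmax(masked, dim=-1)

# [8, 2, 2] zero-padded up to the tile shape [8, 32, 32].
padded = F.pad(torch.rand(8, 2, 2), (0, 30, 0, 30))
out = softmax_over_true_width(padded, true_width=2)
ref = F.softmax(padded[..., :2, :2], dim=-1)
assert torch.allclose(out[..., :2, :2], ref)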

mtairum commented 4 months ago

@eyonland @yugaoTT, what's the status of this issue?

We use a padded tile layout input for the softmax in Mixtral, and we're seeing wrong softmax outputs (last dim) when the input is sharded. Specifically, the issue seems to show up when sharding by height.

I wrote a small-ish test, using sizes similar to our model's, that showcases the issue.

The shard config I'm using is the following:

    shard_config = ttnn.create_sharded_memory_config(
        shape=(1024, 32),
        # shape=(32, 32),
        core_grid=ttnn.CoreGrid(y=4, x=8),
        strategy=ttnn.ShardStrategy.WIDTH,
        # strategy=ttnn.ShardStrategy.HEIGHT,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
        use_height_and_width_as_shard_shape=True,
    )

The tensor to be sharded has shape [1, 32, 32, 32]. When sharding by height with shard shape (32, 32), we see the bad PCC below. However, when sharding by width with shard shape (1024, 32) instead, the PCC is correct compared to the interleaved approach.

After sharding, we slice the input (input = input[:, :, :, :i+1]), which results in a padded tile layout, and then feed it into softmax.

import torch
import pytest
from loguru import logger

import ttnn
from models.utility_functions import comp_pcc

def test_softmax(device, use_program_cache, reset_seeds):
    iterations = 3

    pt_input = torch.rand(1, 32, 32, 32)
    tt_input_il = ttnn.from_torch(pt_input, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
    tt_input_sharded = ttnn.from_torch(pt_input, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)

    shard_config = ttnn.create_sharded_memory_config(
        shape=(1024, 32),
        # shape=(32, 32),
        core_grid=ttnn.CoreGrid(y=4, x=8),
        strategy=ttnn.ShardStrategy.WIDTH,
        # strategy=ttnn.ShardStrategy.HEIGHT,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
        use_height_and_width_as_shard_shape=True,
    )

    # Interleaved
    for i in range(iterations):
        tt_input_il = tt_input_il[:, :, :, :i+1]
        soft_il = ttnn.softmax(tt_input_il, dim=-1)

    # Sharded
    tt_input_sharded = ttnn.to_memory_config(tt_input_sharded, memory_config=shard_config)
    for i in range(iterations):
        tt_input_sharded = tt_input_sharded[:, :, :, :i+1]
        soft_shard = ttnn.softmax(tt_input_sharded, dim=-1)

    passing, pcc_message = comp_pcc(ttnn.to_torch(soft_il), ttnn.to_torch(soft_shard))
    assert passing, f"PCC = {pcc_message}"
yugaoTT commented 4 months ago

We support padding in interleaved softmax now, but not sharded currently. However, your case does not require padding: height sharding should be the correct way to do it if you want your workload to spread across 32 cores.
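
For reference, the height-sharded variant of the config above (the commented-out lines in your snippet) would look like this, splitting the 1024 rows of the [1, 32, 32, 32] input into (32, 32) shards, one tile per core on the 8x4 grid:

import ttnn

shard_config = ttnn.create_sharded_memory_config(
    shape=(32, 32),
    core_grid=ttnn.CoreGrid(y=4, x=8),
    strategy=ttnn.ShardStrategy.HEIGHT,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)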

yugaoTT commented 4 months ago

See a relevant test here:


import torch
import pytest
from loguru import logger

# Assumed imports for this snippet (tt_lib plus the tt-metal test helpers;
# exact paths may vary by version):
import tt_lib as ttl
from models.utility_functions import torch2tt_tensor, untilize, comp_pcc


@pytest.mark.parametrize(
    "in0_mem_config",
    (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),),
    ids=["in0_DRAM"],
)
@pytest.mark.parametrize(
    "in_dtype",
    (ttl.tensor.DataType.BFLOAT8_B,),
    ids=["BFLOAT8_B"],
)
def test_softmax(device, in_dtype, in0_mem_config):
    torch.manual_seed(0)
    sm_op = ttl.operations.primary.softmax_in_place

    grid_size = [8, 4]
    input_shape = (1, 32, 32, 32)

    input_tensor = torch.randn(input_shape).bfloat16().float()
    in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype)
    in1_t_shard = ttl.tensor.interleaved_to_sharded(
        in1_t,
        grid_size,
        [32, 32],
        ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
        ttl.tensor.ShardOrientation.ROW_MAJOR,
    )

    program_config = ttl.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig(
        compute_with_storage_grid_size=grid_size,
        subblock_w=1,
        block_h=1,
        block_w=1,
    )

    tt_output_sharded = sm_op(in1_t_shard, program_config=program_config)

    tt_output = ttl.tensor.sharded_to_interleaved(tt_output_sharded, in0_mem_config)
    tt_output_tensor = tt_output.cpu().to_torch().float()
    tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape)
    tt_output_tensor = untilize(tt_output_tensor)

    golden_output_tensor = torch.softmax(input_tensor, dim=-1)

    allclose, output = comp_pcc(
        tt_output_tensor,
        golden_output_tensor,
    )
    logger.info(output)
    assert allclose, f"FAILED: {output}"
mtairum commented 4 months ago

> We support padding in interleaved softmax now, but not sharded currently. However, your case does not require padding: height sharding should be the correct way to do it if you want your workload to spread across 32 cores.

Thank you for the update!

In our case we do padding since we don't have an attention mask.