
Creating a tensor with a sharded memory config using ttnn.create_sharded_memory_config produces a tensor with low PCC compared to the input tensor when tensor height and width are used as the shard shape and the number of shards exceeds the specified core grid size #15306

Open amalbasaTT opened 1 day ago

amalbasaTT commented 1 day ago

**Describe the bug**
Creating a tensor with a sharded memory config produced by `ttnn.create_sharded_memory_config` with the arguments:

1. `use_height_and_width_as_shard_shape = True`
2. `input_shape = shape_of_the_input_tensor`, where the last two dimensions are used as the shard shape and the remaining dimensions are collapsed into the number of shards
3. `orientation = ttnn.ShardOrientation.ROW_MAJOR`

yields a tensor with low PCC compared to the original input tensor: in all cases when a PCC threshold of 1.0 is used, and in most cases with a threshold of 0.999. When the sharded tensor is used as an input to `ttnn.relu`, PCC is low in all cases. The problem is observed on Wormhole_B0.
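For context, a minimal sketch of how such a memory config might be constructed directly. The concrete shape, `core_grid`, and `strategy` values below are illustrative assumptions; only the `orientation` and `use_height_and_width_as_shard_shape` arguments are taken from the report:

```python
import ttnn

# Illustrative rank-4 input: the last two dims (64, 96) act as the shard shape and
# the leading dims are collapsed into 2 * 2 = 4 shards, as described above.
input_shape = (2, 2, 64, 96)

sharded_mem_config = ttnn.create_sharded_memory_config(
    shape=input_shape,
    core_grid=ttnn.CoreGrid(y=2, x=2),            # assumed core grid
    strategy=ttnn.ShardStrategy.HEIGHT,           # assumed strategy
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)
```

The repro below builds the equivalent config through the sweep-framework helper `get_sharded_config(input_shape, "tensor_hw", device_grid_size, "row_major")`.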

**To Reproduce**
Steps to reproduce the behavior:

  1. Check out the branch `amalbasaTT/unary_sharded-sweeps-2` (soon to be merged to main).
  2. Copy the unit test below into `test_relu_sharded.py`:
    
```python
import torch
import random
import ttnn
import itertools
import pytest
import traceback
import math
from loguru import logger
from functools import partial

from tests.sweep_framework.sweep_utils.utils import gen_shapes, get_device_grid_size, get_sharded_config
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import (
    gen_func_with_cast_tt,
    _gen_reshape_args_from_volume,
)
from tests.ttnn.utils_for_testing import check_with_pcc
from models.utility_functions import torch_random

Y, X = get_device_grid_size()
DEVICE_GRID_SIZE = ttnn.CoreGrid(y=Y, x=X)


def gen_sharded_spec(gen_unsafe, num_shapes, num_core_samples):
    def is_unsafe(num_of_shards, core_y, core_x):
        return num_of_shards > (core_y * core_x)

    for i in range(num_core_samples):
        y = random.randint(1, Y)
        x = random.randint(1, X)
        for j in range(num_shapes):
            for rank in [3, 4]:
                for dtype in [ttnn.bfloat16, ttnn.bfloat8_b]:
                    data_seed = random.randint(0, 20000000)

                    min_tensor_height = 32
                    min_tensor_width = 32
                    mul_height = random.randint(1, 10)
                    mul_width = random.randint(1, 10)
                    tensor_height = min_tensor_height * mul_height
                    tensor_width = min_tensor_width * mul_width
                    input_shape = [tensor_height, tensor_width]

                    # num_shards corresponds to the product of the remaining dims of the input tensor
                    num_shards = random.randint(1, 100)

                    if gen_unsafe:
                        while not is_unsafe(num_shards, y, x):
                            num_shards = random.randint(1, 100)
                    else:
                        while is_unsafe(num_shards, y, x):
                            num_shards = random.randint(1, 100)

                    rest_dims = random.choice(_gen_reshape_args_from_volume(num_shards, step=1, out_dims=rank - 2))
                    rest_dims = list(rest_dims["reshape_dims"])
                    input_shape = rest_dims + input_shape

                    device_grid_size = ttnn.CoreGrid(y=y, x=x)
                    mem_cfg = get_sharded_config(input_shape, "tensor_hw", device_grid_size, "row_major")

                    yield (input_shape, dtype, ttnn.TILE_LAYOUT, mem_cfg, data_seed)


test_sweep_args = list(gen_sharded_spec(True, 4, 4))


def run_relu_sharded_tests(input_shape, dtype, dlayout, mem_cfg, data_seed, device):
    torch.manual_seed(data_seed)

    x = gen_func_with_cast_tt(
        partial(torch_random, low=-100, high=100, dtype=torch.bfloat16), dtype
    )(input_shape)

    try:
        ref_value = torch.nn.functional.relu(x)

        tt_x = ttnn.from_torch(
            x,
            dtype=dtype,
            layout=dlayout,
            device=device,
            memory_config=mem_cfg,
        )

        tt_result = ttnn.relu(tt_x, memory_config=mem_cfg)
        tt_result = ttnn.to_torch(tt_result)

    except Exception as e:
        logger.warning(f"Test execution crashed: {e}")
        print(traceback.format_exc())
        raise e

    # PCC of the sharded tensor round trip against the original input (before ttnn.relu)
    passed, output_str = check_with_pcc(x, ttnn.to_torch(tt_x), 1.0)
    assert passed, f"Failed before ttnn.relu {output_str}, {input_shape}, {dtype}, {mem_cfg.shard_spec}"
    # PCC of the ttnn.relu output against the torch reference
    passed, output_str = check_with_pcc(ref_value, tt_result, 0.999)
    assert passed, f"Failed at ttnn.relu, {output_str}, {input_shape}, {dtype}, {mem_cfg.shard_spec}"


@pytest.mark.parametrize(
    "input_shape, dtype, dlayout, mem_cfg, data_seed",
    (test_sweep_args),
)
def test_relu_sharded(input_shape, dtype, dlayout, mem_cfg, data_seed, device):
    run_relu_sharded_tests(input_shape, dtype, dlayout, mem_cfg, data_seed, device)
```

3. Run it with the command:

```
pytest path/to/test_relu_sharded.py
```



**Expected behavior**
When the parameter `gen_unsafe` of `gen_sharded_spec` is set to `True`, all test cases fail; otherwise, all test cases pass.
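
To make the failing regime concrete: the cases generated with `gen_unsafe=True` are exactly those where the collapsed leading dimensions produce more shards than the sampled core grid has cores, mirroring the `is_unsafe` helper in the test. A small standalone illustration (the shape and grid sizes are arbitrary examples, not values from an actual run):

```python
def is_unsafe(num_of_shards, core_y, core_x):
    # Same condition as in the test: more shards than cores in the grid.
    return num_of_shards > (core_y * core_x)

# A [5, 7, 64, 96] input collapses its leading dims into 5 * 7 = 35 shards of 64x96.
print(is_unsafe(5 * 7, 6, 6))  # False: 35 <= 36 cores, "safe" -> expected to pass
print(is_unsafe(5 * 7, 4, 4))  # True:  35 > 16 cores, "unsafe" -> expected to fail
```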