tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low-level kernel programming model.
Apache License 2.0

ttnn.div_bw with round_mode in ['trunc', 'floor'] throws "Unsupported DataType" fatal error when grad_tensor is not bfloat8_b or input tensors a and b are not bfloat16 #13976

Closed: amalbasaTT closed this issue 2 days ago

amalbasaTT commented 3 weeks ago

Describe the bug

ttnn.div_bw with round_mode in ['trunc', 'floor'] works only when grad_dtype is bfloat8_b and both input_a_dtype and input_b_dtype are bfloat16. For any other dtype combination it fails with the fatal error "Unsupported DataType".
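
For context, `round_mode` selects how the quotient `a / b` is rounded in the forward division. The plain-Python sketch below is not ttnn code (`rounded_div` is a hypothetical helper) and assumes the same semantics as PyTorch's `torch.div(..., rounding_mode=...)`: "floor" rounds toward negative infinity and "trunc" rounds toward zero, so the two modes disagree only for negative, non-integral quotients.

```python
import math


def rounded_div(a: float, b: float, round_mode=None) -> float:
    """Hypothetical helper: divide a by b, then round according to round_mode."""
    q = a / b
    if round_mode is None:
        return q  # true division, no rounding
    if round_mode == "floor":
        return float(math.floor(q))  # round toward negative infinity
    if round_mode == "trunc":
        return float(math.trunc(q))  # round toward zero
    raise ValueError(f"unsupported round_mode: {round_mode!r}")


for a, b in [(7.0, 2.0), (-7.0, 2.0)]:
    print(a, b, rounded_div(a, b, "floor"), rounded_div(a, b, "trunc"))
# 7.0 2.0 3.0 3.0
# -7.0 2.0 -4.0 -3.0
```

Inputs that produce negative quotients are therefore needed to tell the two modes apart.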

To Reproduce

Steps to reproduce the behavior. The sweep test for div_bw is located in 'tests/sweep_framework/sweeps/eltwise/binary_backward/div_bw/div_bw.py'.

  1. Check out branch amalbasaTT/backward_ops-sweeps-4 (soon to be merged into main)

  2. Go to 'tests/sweep_framework/sweeps/eltwise/binary_backward/div_bw/div_bw.py'

  3. In the "xfail" suite:

    • replace "round_mode": ["None", "floor", "trunc"], with "round_mode": ["floor", "trunc"],
  4. Generate new parameter vectors and run the sweep test

    python3 tests/sweep_framework/sweeps_parameter_generator.py --elastic cloud --module-name eltwise.binary_backward.div_bw.div_bw
    python3 tests/sweep_framework/sweeps_runner.py --elastic cloud --module-name eltwise.binary_backward.div_bw.div_bw --suite-name xfail
  5. See the error. The results can be found on Elastic Cloud, as explained here: https://github.com/tenstorrent/tt-metal/tree/main/tests/sweep_framework
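
The effect of step 3 on the generated vectors can be sketched in plain Python. This assumes the parameter generator expands a suite into the Cartesian product of its value lists. Apart from `round_mode` and the `grad_dtype`, `input_a_dtype` and `input_b_dtype` names taken from this issue, the suite contents and the helpers `expand` and `previously_supported` are hypothetical, and dtypes are plain strings standing in for `ttnn` dtypes. In this toy suite, only 2 of the 16 vectors use the single combination reported as working.

```python
import itertools

# Hypothetical excerpt of the modified "xfail" suite (strings stand in for ttnn dtypes).
xfail_suite = {
    "grad_dtype": ["bfloat16", "bfloat8_b"],
    "input_a_dtype": ["bfloat16", "bfloat8_b"],
    "input_b_dtype": ["bfloat16", "bfloat8_b"],
    "round_mode": ["floor", "trunc"],
}


def expand(suite):
    """Assumed behaviour: one vector per element of the Cartesian product."""
    keys = list(suite)
    return [dict(zip(keys, values)) for values in itertools.product(*(suite[k] for k in keys))]


def previously_supported(vector):
    """The only dtype combination the issue reports as working before the fix."""
    return (
        vector["grad_dtype"] == "bfloat8_b"
        and vector["input_a_dtype"] == "bfloat16"
        and vector["input_b_dtype"] == "bfloat16"
    )


vectors = expand(xfail_suite)
expected_to_fail = [v for v in vectors if not previously_supported(v)]
print(len(vectors), len(expected_to_fail))  # 16 14
```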

umadevimcw commented 3 days ago

@amalbasaTT Can you provide a simple unit test or a short Python snippet that reproduces the error you are describing?

amalbasaTT commented 2 days ago


It seems this was fixed in the meantime, but here is the unit test:

# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from loguru import logger
from functools import partial
import pytest
import torch
import ttnn
import traceback

from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
from models.utility_functions import torch_random

def run_backward_div_tests(
    input_shape,
    round_mode,
    dtype,
    dlayout,
    in_mem_cfg,
    out_mem_cfg,
    data_seed,
    device,
):
    torch.manual_seed(data_seed)
    # The sweep encodes "no rounding" as the string "None"; map it to Python None
    # so that the golden function and ttnn.div_bw receive the same value.
    round_mode = None if round_mode == "None" else round_mode

    # grad tensor
    x = gen_func_with_cast_tt(
        partial(torch_random, low=-100, high=100, dtype=torch.float32), dtype[0]
    )(input_shape[0])
    # input tensor (numerator)
    y = gen_func_with_cast_tt(
        partial(torch_random, low=-100, high=100, dtype=torch.float32), dtype[1]
    )(input_shape[0])
    # other input tensor (divisor): magnitude in [0.1, 100] to stay away from zero
    z = gen_func_with_cast_tt(
        partial(torch_random, low=0.1, high=100, dtype=torch.float32), dtype[2]
    )(input_shape[0])
    # give each divisor element a random sign (+1 or -1)
    signs_b = torch.randint(0, 2, input_shape[0]) * 2 - 1
    z *= signs_b

    y.requires_grad = True
    z.requires_grad = True

    try:
        # reference gradients from the golden (PyTorch) implementation
        golden_function = ttnn.get_golden_function(ttnn.div_bw)
        ref_values = golden_function(x, y, z, round_mode)

        # move grad and inputs to the device with the requested dtypes and memory configs
        tt_x = ttnn.from_torch(x, dtype=dtype[0], layout=dlayout[0], device=device, memory_config=in_mem_cfg[0])
        tt_y = ttnn.from_torch(y, dtype=dtype[1], layout=dlayout[0], device=device, memory_config=in_mem_cfg[1])
        tt_z = ttnn.from_torch(z, dtype=dtype[2], layout=dlayout[0], device=device, memory_config=in_mem_cfg[2])

        tt_results = ttnn.div_bw(tt_x, tt_y, tt_z, round_mode=round_mode, memory_config=out_mem_cfg)

    except Exception as e:
        logger.warning(f"Test execution crashed: {e}")
        print(traceback.format_exc())
        raise

    for i in range(len(ref_values)):
        tt_result, ref_value = ttnn.to_torch(tt_results[i]), ref_values[i]
        assert len(tt_result.shape) == len(ref_value.shape)
        assert tt_result.shape == ref_value.shape
        assert_with_pcc(ref_value, tt_result, 0.999)

test_sweep_args = [
    (
        [(3, 2, 192, 32)],
        "floor",
        [ttnn.bfloat16, ttnn.bfloat8_b, ttnn.bfloat8_b],
        [ttnn.TILE_LAYOUT],
        [ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG],
        ttnn.DRAM_MEMORY_CONFIG,
        11079580,
    ),
]

@pytest.mark.parametrize(
    "input_shape, round_mode, dtype, dlayout, in_mem_config, out_mem_config, data_seed",
    (test_sweep_args),
)
def test_backward_div(input_shape, round_mode, dtype, dlayout, in_mem_config, out_mem_config, data_seed, device):
    run_backward_div_tests(input_shape, round_mode, dtype, dlayout, in_mem_config, out_mem_config, data_seed, device)
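
The final loop compares each reference gradient with the device result using `assert_with_pcc(ref_value, tt_result, 0.999)`, i.e. a Pearson correlation coefficient (PCC) threshold of 0.999. A plain-Python sketch of that coefficient follows. `pearson` is a hypothetical name; the real helper in `tests.ttnn.utils_for_testing` works on tensors and may handle edge cases such as constant or all-zero tensors differently.

```python
import math
from typing import Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mx = sum(x) / len(x)
    my = sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0.0 or sy == 0.0:
        # Undefined for constant inputs; a library helper may special-case this.
        raise ValueError("PCC is undefined when either input is constant")
    return cov / (sx * sy)


reference = [1.0, -2.0, 3.5, 0.25]
device = [1.01, -1.98, 3.49, 0.26]  # small, rounding-like deviations
print(pearson(reference, device) >= 0.999)  # True
```
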

umadevimcw commented 2 days ago

@amalbasaTT If that's the case, can we close this issue?

KalaivaniMCW commented 2 days ago

The provided unit test passes on the latest main.

amalbasaTT commented 2 days ago

It seems to be fixed, so you can close the issue.

umadevimcw commented 2 days ago

Sure, thanks.