tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

Mul op returns wrong output when multiplied with NaN #12776

Open ruthreshx opened 2 months ago

ruthreshx commented 2 months ago

Mul op returns wrong output:

Actual output: NaN * 0.5 returns 255211775190703847597530955573826158592.00000

Expected output: NaN * 0.5 = NaN

We also tried the nan_gs value read from the device: 6.9752e19 * 0.5 returns 34875875514357121024.00000.

Note: I noticed one more thing: if NaN is multiplied by any decimal value from 0.0 to 0.66, it returns the large value above; otherwise it returns NaN as expected.

Observed on both GS and WH.
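For reference, and purely my own reading of the numbers (not confirmed anywhere in this thread): the large value is exactly what you get if the bfloat16 quiet-NaN bit pattern 0x7FC0 is decoded as an ordinary finite number, 1.5 * 2^128, and then halved; likewise 34875875514357121024.0 is simply half of the reported nan_gs constant. A quick host-side sanity check in plain Python:

# Decode the bfloat16 quiet-NaN pattern 0x7FC0 as if its exponent field
# were a normal (non-special) exponent, then multiply by 0.5.
bits = 0x7FC0                        # bfloat16 quiet NaN: sign 0, exp 0xFF, mantissa 0x40
sign = bits >> 15
exp = (bits >> 8) & 0xFF             # 255, normally reserved for inf/NaN
mant = bits & 0x7F                   # 0x40 -> fractional part 0.5
as_finite = (-1) ** sign * (1 + mant / 128) * 2.0 ** (exp - 127)
print(as_finite)                     # ~5.1042e+38, i.e. 1.5 * 2**128
print(as_finite * 0.5)               # ~2.5521e+38, matching 255211775190703847597530955573826158592.0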

SeanNijjar commented 1 month ago

Redirecting, as I am not the right person to look at this. Assigned to @ttmtrajkovic to redirect.

ttmtrajkovic commented 1 month ago

@ruthreshx Which op is this? Please specify the test you were using and how to reproduce the problem.

ruthreshx commented 1 month ago

Hi @ttmtrajkovic, please find the test below to reproduce the issue. Device: GS, WH300. Op: MUL

# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

from loguru import logger
import random
import pytest
import torch
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.ttnn.python_api_testing.sweep_tests import ttnn_ops

def run_eltwise_mul_tests(input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed, device):
    torch.manual_seed(data_seed)

    x = torch.Tensor(size=input_shape[0]).uniform_(0, 0.66).to(torch.bfloat16)
    y = torch.full(size=input_shape[1], fill_value=float('nan'))

    try:
        # get ref result
        ref_value = torch.mul(x, y)

        x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config, dtype[0])
        y = ttnn_ops.setup_ttnn_tensor(y, device, dlayout[1], in_mem_config, dtype[1])

        tt_result = ttnn.mul(x, y)
        tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config)

        print("tt_result ==> ", tt_result)
        logger.info(f"Mul run {input_shape[0]} * {input_shape[1]} finished")

    except Exception as e:
        logger.warning(f"Operation execution crashed: {e}")
        raise e

    assert len(tt_result.shape) == len(ref_value.shape)
    assert tt_result.shape == ref_value.shape
    assert_with_pcc(ref_value, tt_result, 0.99)

test_sweep_args = [
    (
        [(32, 32), (32, 32)],
        [ttnn.bfloat16, ttnn.bfloat16],
        [ttnn.TILE_LAYOUT, ttnn.TILE_LAYOUT],
        ttnn.DRAM_MEMORY_CONFIG,
        ttnn.DRAM_MEMORY_CONFIG,
        17799073,
    ),
]

def test_eltwise_mul(device):
    for i in range(1):
        for input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed in test_sweep_args:
            run_eltwise_mul_tests(input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed, device)
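
In case it is easier to run outside the sweep-test harness, here is a minimal standalone sketch of the same repro (a sketch only, assuming the usual ttnn.open_device / ttnn.from_torch / ttnn.to_torch flow; I have not run this exact script):

import torch
import ttnn

# Minimal repro sketch: NaN * 0.5 on device vs. the torch reference.
device = ttnn.open_device(device_id=0)

x = torch.full((32, 32), float("nan"), dtype=torch.bfloat16)
y = torch.full((32, 32), 0.5, dtype=torch.bfloat16)

tt_x = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
tt_y = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

tt_out = ttnn.to_torch(ttnn.mul(tt_x, tt_y))

print("torch:", torch.mul(x, y)[0, 0])   # expected: nan
print("ttnn :", tt_out[0, 0])            # observed: a huge finite value instead of nan

ttnn.close_device(device)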

Let me know if anything else is required :)