tenstorrent / tt-metal

TT-NN operator library and TT-Metalium low-level kernel programming model.
Apache License 2.0

[Bug Report] TTNN Embedding op (ttnn.embedding) fails when configured as tile layout #12644

Status: Open. nvukobratTT opened this issue 6 days ago.

nvukobratTT commented 6 days ago

Describe the bug

The embedding op (ttnn.embedding) fails during an implicit reshape when the tensors use TILE_LAYOUT.

For more context, here is the exact error message:

Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first.
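The message points at a tile-alignment check. Assuming TILE_LAYOUT stores data in the default 32x32 tiles, a tensor can only be reshaped in that layout without padding if its last two dimensions are multiples of 32. The plain-Python sketch below illustrates that condition for the shapes in the repro below; `TILE_SIZE`, `is_tile_aligned` and `padded_tile_shape` are hypothetical helpers for illustration, not ttnn APIs, and ttnn's internal check may differ in detail.

```python
TILE_SIZE = 32  # assumed default tile height/width for TILE_LAYOUT


def is_tile_aligned(shape, tile=TILE_SIZE):
    """Hypothetical helper: True if the last two dims are multiples of `tile`.

    Expects a shape of rank >= 2.
    """
    *_, height, width = shape
    return height % tile == 0 and width % tile == 0


def padded_tile_shape(shape, tile=TILE_SIZE):
    """Hypothetical helper: round the last two dims up to whole tiles."""
    *batch, height, width = shape

    def round_up(n):
        # Ceiling division, then scale back up to a multiple of `tile`.
        return -(-n // tile) * tile

    return (*batch, round_up(height), round_up(width))


# Index tensor, embedding output, and weight table shapes from the repro.
for shape in [(1, 12), (1, 12, 3200), (32000, 3200)]:
    print(shape, is_tile_aligned(shape), padded_tile_shape(shape))
# (1, 12) False (32, 32)
# (1, 12, 3200) False (1, 32, 3200)
# (32000, 3200) True (32000, 3200)
```

Under this assumption the (32000, 3200) weight table is tile-aligned, while the (1, 12) index tensor and the (1, 12, 3200) output are not, which is consistent with a reshape of those tensors being rejected in TILE_LAYOUT.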

To Reproduce

The following test reproduces the failure:

import pytest
import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import torch_random

def test_embedding_llama3_tile(device):
    torch.manual_seed(1234)

    # Create a torch input tensor (1, 12) and weight tensor (32000, 3200), then run torch.nn.functional.embedding as the reference
    torch_input_tensor = torch.randint(0, 32000, (1, 12)).to(torch.int32)
    torch_weights = torch_random((32000, 3200), -0.1, 0.1, dtype=torch.bfloat16)
    torch_output_tensor = torch.nn.functional.embedding(torch_input_tensor, torch_weights)

    # Move the input (1, 12) and weights (32000, 3200) to the device in TILE_LAYOUT, then run ttnn.embedding
    input_tensor = ttnn.to_device(
        ttnn.from_torch(torch_input_tensor, dtype=ttnn.uint32, layout=ttnn.TILE_LAYOUT),
        device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    weights = ttnn.to_device(
        ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT),
        device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    output_tensor = ttnn.embedding(input_tensor, weights, memory_config=ttnn.DRAM_MEMORY_CONFIG)

    # Convert ttnn to torch tensor
    output_tensor = ttnn.to_torch(output_tensor)

    assert_with_pcc(torch_output_tensor, output_tensor)
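For reference, the torch side of this test, torch.nn.functional.embedding, is a row gather: each index selects one row of the weight table, so a (1, 12) index tensor against a (32000, 3200) table yields a (1, 12, 3200) output. A minimal pure-Python sketch of that lookup on a tiny table follows; `embedding_lookup` is a hypothetical name used only for illustration.

```python
def embedding_lookup(indices, weights):
    """Hypothetical reference: gather weights[idx] for every index.

    `indices` is a list of rows of ints; `weights` is a list of rows
    (vocab_size x embed_dim). The result has shape
    len(indices) x len(indices[0]) x embed_dim.
    """
    vocab_size = len(weights)
    output = []
    for row in indices:
        out_row = []
        for idx in row:
            if not 0 <= idx < vocab_size:
                raise IndexError(f"index {idx} out of range for vocab size {vocab_size}")
            out_row.append(list(weights[idx]))  # copy the selected row
        output.append(out_row)
    return output


# Tiny stand-in for the (32000, 3200) table: vocab of 4, embedding dim of 3.
weights = [
    [0.0, 0.1, 0.2],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]
indices = [[3, 0, 2]]  # shape (1, 3), analogous to the (1, 12) input
out = embedding_lookup(indices, weights)
print(out)  # [[[3.0, 3.1, 3.2], [0.0, 0.1, 0.2], [2.0, 2.1, 2.2]]]
```

The output's trailing dimension always equals the embedding width, so only the index tensor's shape decides whether the output's second-to-last dimension is tile-aligned.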

Note: the same test passes when both tensors use ROW_MAJOR_LAYOUT:

import pytest
import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import torch_random

def test_embedding_llama3_row_major(device):
    torch.manual_seed(1234)

    # Create a torch input tensor (1, 12) and weight tensor (32000, 3200), then run torch.nn.functional.embedding as the reference
    torch_input_tensor = torch.randint(0, 32000, (1, 12)).to(torch.int32)
    torch_weights = torch_random((32000, 3200), -0.1, 0.1, dtype=torch.bfloat16)
    torch_output_tensor = torch.nn.functional.embedding(torch_input_tensor, torch_weights)

    # Move the input (1, 12) and weights (32000, 3200) to the device in ROW_MAJOR_LAYOUT, then run ttnn.embedding
    input_tensor = ttnn.to_device(
        ttnn.from_torch(torch_input_tensor, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT),
        device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    weights = ttnn.to_device(
        ttnn.from_torch(torch_weights, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT),
        device,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    output_tensor = ttnn.embedding(input_tensor, weights, memory_config=ttnn.DRAM_MEMORY_CONFIG)

    # Convert ttnn to torch tensor
    output_tensor = ttnn.to_torch(output_tensor)

    # PCC check
    assert_with_pcc(torch_output_tensor, output_tensor)
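Both tests end with assert_with_pcc, which compares the device output against the torch reference by Pearson correlation coefficient (PCC) rather than exact equality, so small bfloat16 rounding differences still pass. The stdlib-only sketch below shows that metric; `pearson_cc`, `check_pcc` and the 0.99 threshold are assumptions for illustration, and the real helper in tests.ttnn.utils_for_testing may use a different default threshold and handle edge cases (NaNs, constant tensors) differently.

```python
import math


def pearson_cc(expected, actual):
    """Hypothetical helper: Pearson correlation of two equal-length sequences."""
    if len(expected) != len(actual) or not expected:
        raise ValueError("inputs must be non-empty and the same length")
    n = len(expected)
    mean_e = sum(expected) / n
    mean_a = sum(actual) / n
    cov = sum((e - mean_e) * (a - mean_a) for e, a in zip(expected, actual))
    var_e = sum((e - mean_e) ** 2 for e in expected)
    var_a = sum((a - mean_a) ** 2 for a in actual)
    if var_e == 0 or var_a == 0:
        # Correlation is undefined for constant inputs; treat exact matches as passing.
        return 1.0 if list(expected) == list(actual) else 0.0
    return cov / math.sqrt(var_e * var_a)


def check_pcc(expected, actual, threshold=0.99):
    """Hypothetical helper: return (passed, pcc) for an assumed threshold."""
    pcc = pearson_cc(expected, actual)
    return pcc >= threshold, pcc


golden = [0.05, -0.02, 0.09, -0.07, 0.01]
# Small per-element errors, roughly like low-precision rounding on device.
device_out = [0.0501, -0.0198, 0.0899, -0.0702, 0.0101]
print(check_pcc(golden, device_out))              # passes, PCC very close to 1
print(check_pcc(golden, list(reversed(golden))))  # fails: rows in the wrong order
```

A correlation metric tolerates uniform small errors but catches structural mistakes such as gathering the wrong rows, which is what matters when comparing ttnn.embedding against the torch reference.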

Expected behavior

The embedding op should also pass when the tensors use TILE_LAYOUT.


tt-mpantic commented 10 hours ago

This seems to be a duplicate of #12866.