tenstorrent / pytorch2.0_ttnn

⭐️ TTNN Compiler for PyTorch 2.0 ⭐️ Enables running PyTorch 2.0 models on Tenstorrent hardware.
https://tenstorrent.github.io/tt-metal/latest/ttnn/

`ttnn.permute` fails with inner-most dim = 1 when doing transpose #377

Open jerrysky3 opened 2 days ago

jerrysky3 commented 2 days ago

ttnn.permute fails when transposing a tensor whose inner-most dimension is 1 (tested with 2D tensors, e.g. shape (2, 1) with permutation (1, 0)).

This is blocking the aten.t.default to ttnn.permute conversion in models.

To reproduce

Run the example below:

import torch
import ttnn

def main(device):
    # A 2D tensor whose inner-most dimension is 1.
    torch_tensor = torch.rand((2, 1), dtype=torch.bfloat16)
    print(torch_tensor)
    input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device)
    # Transpose via permute; this is the call that raises the RuntimeError.
    input_tensor = ttnn.permute(input_tensor, (1, 0))
    print(ttnn.to_torch(input_tensor), input_tensor.shape)

if __name__ == "__main__":
    device = ttnn.open_device(device_id=0)
    try:
        main(device)
    finally:
        ttnn.close_device(device)

It currently fails with the following error:

RuntimeError: TT_FATAL @ ../ttnn/cpp/ttnn/tensor/types.cpp:211: normalized_index >= 0 and normalized_index < rank
info:
Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615
backtrace:
 --- /home/jerry/tt-metal/ttnn/ttnn/_ttnn.so(+0x100ef49) [0x7f53b10def49]
 --- ttnn::types::Shape::operator[](long) const
 --- ttnn::operations::data_movement::ReshapeViewOperation::invoke(tt::tt_metal::Tensor const&, ttnn::types::Shape const&)
 --- /home/jerry/tt-metal/ttnn/ttnn/_ttnn.so(+0x4e7f92) [0x7f53b05b7f92]
 --- ttnn::operations::data_movement::ExecutePermute::invoke(unsigned char, tt::tt_metal::Tensor const&, tt::stl::Span<long const, 18446744073709551615ul>, std::__1::optional<tt::tt_metal::MemoryConfig> const&, bool)
 --- /home/jerry/tt-metal/ttnn/ttnn/_ttnn.so(+0x12a60cb) [0x7f53b13760cb]
 --- /home/jerry/tt-metal/ttnn/ttnn/_ttnn.so(+0x11fad8b) [0x7f53b12cad8b]
 --- /home/jerry/tt-metal/ttnn/ttnn/_ttnn.so(+0x12511ce) [0x7f53b13211ce]
 --- python(PyCFunction_Call+0x59) [0x5e66b9]
 --- python(_PyObject_MakeTpCall+0x29e) [0x5e728e]
 --- python() [0x4f9588]
 --- python(PyObject_Call+0x62) [0x5e5e32]
 --- python() [0x58db4c]
 --- python(PyObject_Call+0x25e) [0x5e602e]
 --- python(_PyEval_EvalFrameDefault+0x1f34) [0x55e124]
 --- python(_PyEval_EvalCodeWithName+0x26a) [0x55abda]
 --- python(_PyFunction_Vectorcall+0x393) [0x5e6c43]
 --- python() [0x58d7be]
 --- python(_PyObject_MakeTpCall+0x29e) [0x5e728e]
 --- python(_PyEval_EvalFrameDefault+0x5dac) [0x561f9c]
 --- python(_PyFunction_Vectorcall+0x1b6) [0x5e6a66]
 --- python(_PyEval_EvalFrameDefault+0x72d) [0x55c91d]
 --- python(_PyEval_EvalCodeWithName+0x26a) [0x55abda]
 --- python(PyEval_EvalCode+0x27) [0x68bfe7]
 --- python() [0x67d831]
 --- python() [0x67d8af]
 --- python() [0x67d951]
 --- python(PyRun_SimpleFileExFlags+0x197) [0x67e5e7]
 --- python(Py_RunMain+0x212) [0x6b5732]
 --- python(Py_BytesMain+0x2d) [0x6b5abd]
 --- /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xf3) [0x7f541e73a083]
 --- python(_start+0x2e) [0x5eb5ee]

Device | INFO     | Closing user mode device drivers

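A note that may help while debugging: the out-of-range index in the error message, 18446744073709551615, equals 2**64 - 1, which is the value -1 takes when it is stored in an unsigned 64-bit integer. The sketch below only checks that arithmetic; `as_uint64` is a hypothetical helper written for illustration, and it does not confirm where in the reshape path a negative index is produced.

```python
UINT64_MODULUS = 2 ** 64


def as_uint64(value):
    """Reinterpret a Python int as an unsigned 64-bit value (two's complement wrap)."""
    return value % UINT64_MODULUS


# Index value copied from the TT_FATAL message above.
reported_index = 18446744073709551615

# -1 wrapped to unsigned 64-bit matches the reported index.
print(as_uint64(-1) == reported_index)
```

If that reading is right, the failure would be consistent with a negative index such as -1 reaching the shape lookup without being normalized, but that is an assumption, not something the backtrace proves.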
jerrysky3 commented 2 days ago

A potential workaround for this case is to use ttnn.reshape instead of ttnn.permute, reshaping the (x, 1) tensor to (1, x).
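A minimal sketch of why the reshape workaround should give the same result, written in plain Python so it runs without a device. It assumes row-major element order; `transpose_as_reshape_shape`, `flatten`, and `transpose` are hypothetical helpers for illustration and are not part of ttnn.

```python
def transpose_as_reshape_shape(shape):
    """Return the reshape target if a 2D transpose is equivalent to a reshape, else None.

    For a row-major (rows, cols) tensor, transposing leaves the flat element
    order unchanged when rows == 1 or cols == 1, so a reshape is sufficient.
    """
    if len(shape) != 2:
        return None
    rows, cols = shape
    if rows == 1 or cols == 1:
        return (cols, rows)
    return None


def flatten(matrix):
    """Flatten a nested list in row-major order."""
    return [value for row in matrix for value in row]


def transpose(matrix):
    """Transpose a nested-list matrix."""
    return [list(column) for column in zip(*matrix)]


column = [[0.25], [0.75]]  # shape (2, 1), as in the repro above

# The transpose keeps the flat order, so reshaping (2, 1) -> (1, 2) matches it.
assert flatten(transpose(column)) == flatten(column)

print(transpose_as_reshape_shape((2, 1)))  # reshape target for the failing case
print(transpose_as_reshape_shape((2, 3)))  # None: a real permute is still needed

# Assumed usage in the conversion (not verified on hardware):
#   target = transpose_as_reshape_shape(tuple(input_tensor.shape))
#   if target is not None: use ttnn.reshape(input_tensor, target)
#   else: use ttnn.permute(input_tensor, (1, 0))
```

The check only covers the 2D case from this issue; tensors of higher rank would still go through ttnn.permute.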