tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

[Bug Report] `ttnn::reshape_on_device` failing to reshape from [1, 1, 2[32], 256] to [1, 2, 1, 256] #14922

Open marty1885 opened 2 weeks ago

marty1885 commented 2 weeks ago

Describe the bug

Bug discovered during bring-up of Gemma 2B on my GGML backend: Gemma triggers a crash during reshape. With some debug prints I was able to narrow down the issue. Gemma asks to reshape a tensor from [1, 1, 2, 256] to [1, 2, 1, 256]. This is handled by ttnn::reshape_on_device, which should work according to the docstring, which indicates the function always works when the last dimension of the source and destination tensors is the same.

https://github.com/tenstorrent/tt-metal/blob/b75f637ffc87422a609c529a08881908c92de26b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp#L56

The issue can be replicated with a simple TTNN program.

#include <cstddef>
#include <ttnn/core.hpp>
#include <ttnn/operations/eltwise/unary/unary.hpp>
#include <ttnn/operations/creation.hpp>
#include <ttnn/device.hpp>
#include <ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp>
#include "ttnn/operations/data_movement/reshape_on_device/reshape.hpp"

#include "common/bfloat16.hpp"

#include <vector>
#include <iostream>

int main()
{
    auto& device = ttnn::open_device(0);
    // Create a [1, 1, 2, 256] tensor and tilize it (pads the row dimension up to the 32-row tile height).
    auto t = ttnn::zeros(ttnn::SimpleShape({1, 1, 2, 256})).to(&device);
    t = ttnn::tilize_with_zero_padding(t);

    // Crashes here, even though the last dimensions match.
    auto q = ttnn::reshape_on_device(t, ttnn::SimpleShape({1, 2, 1, 256}));
    ttnn::close_device(device);
}

with output:

                 Always | FATAL    | 8192 != 512
terminate called after throwing an instance of 'std::runtime_error'
  what():  TT_ASSERT @ /home/marty/Documents/tt/tt-metal/ttnn/cpp/ttnn/tensor/tensor_ops.cpp:341: input_tensor.volume() == new_padded_shape.volume()
info:
8192 != 512
backtrace:
 --- tt::tt_metal::Tensor::reshape(ttnn::types::Shape const&) const
 --- ttnn::operations::data_movement::ReshapeOperation::invoke(unsigned char, tt::tt_metal::Tensor const&, ttnn::types::Shape, std::optional<tt::tt_metal::MemoryConfig> const&)
 --- ttnn::operations::data_movement::ReshapeOperation::invoke(tt::tt_metal::Tensor const&, ttnn::types::Shape const&)
 --- ./ttnn-hello(+0x30211) [0x643c722cb211]
 --- ./ttnn-hello(+0x30287) [0x643c722cb287]
 --- ./ttnn-hello(+0x3038d) [0x643c722cb38d]
 --- ./ttnn-hello(+0x13724) [0x643c722ae724]
 --- /usr/lib/libc.so.6(+0x25e08) [0x7bbcdb234e08]
 --- /usr/lib/libc.so.6(__libc_start_main+0x8c) [0x7bbcdb234ecc]
 --- ./ttnn-hello(+0x11b35) [0x643c722acb35]

To Reproduce

Steps to reproduce the behavior:

  1. Compile and run the provided example
  2. Observe the error

Expected behavior

reshape_on_device should have reshaped the tensor correctly.


Additional context

If relevant, this is the debug print and the log in GGML that led to the bug discovery.

if(tensor.shape()[-1] == (uint32_t)node->ne[0]) {
    // Fast path. reshape_on_device() can reshape if the last dimensions are the same
    std::cerr << "Fast path 2 in reshape_tt_tensor_into_ggml() for tensor " << node->name << std::endl;
    std::cerr << "  tensor shape: " << tensor.shape() << std::endl;
    std::cerr << "  target shape: " << ttnn::SimpleShape(target_shape) << std::endl;
    return ttnn::reshape_on_device(tensor, ttnn::SimpleShape(target_shape));
}

Log:

Fast path 2 in reshape_tt_tensor_into_ggml() for tensor Kcur-17 (reshaped)
  tensor shape: ttnn.Shape([1, 1, 2[32], 256])
  target shape: SimpleShape([1, 2, 1, 256])
                 Always | FATAL    | 8192 != 512
terminate called after throwing an instance of 'std::runtime_error'
  what():  TT_ASSERT @ /home/marty/Documents/tt/tt-metal/ttnn/cpp/ttnn/tensor/tensor_ops.cpp:341: input_tensor.volume() == new_padded_shape.volume()
info:
8192 != 512
jvegaTT commented 1 week ago

reshape_on_device will be deprecated soon and replaced with an improved reshape call. The goal is that reshape will either be a zero-cost view or padding-aware on device, with no host fallback. The call above should definitely be supported in the new version, but currently the required added padding causes an issue because reshape_on_device is not padding aware. Note that in tiled layout [1, 1, 2, 256] requires 8 tiles but [1, 2, 1, 256] requires 16 tiles, with the excess being extra padding that needs to be added, which reshape_on_device cannot do.

marty1885 commented 1 week ago

Thanks. I'll close the issue once reshape_on_device is deprecated. Hope we can have better docs in the future.