marty1885 opened 2 weeks ago
reshape_on_device will be deprecated soon and replaced with an improved reshape call. The goal is that reshape will either be a zero-cost view or an on-device, padding-aware operation, with no possible host fallback. The above reshape should definitely be supported in the new version, but I believe the required added padding currently causes an issue, as reshape_on_device is not padding aware. Note that in tiled tensors [1,1,2,256] requires 8 tiles but [1,2,1,256] requires 16 tiles, with the excess being all extra padding that needs to be added, which reshape_on_device cannot do.
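For illustration, the tile arithmetic works out as follows. This is a minimal sketch assuming the standard 32x32 tile size used for tiled layouts; the tile_count helper is hypothetical, not a ttnn API:

```python
# Tile-count arithmetic behind the comment above (illustrative only).
import math

TILE = 32  # assumed 32x32 tile size for TILE_LAYOUT tensors

def tile_count(shape):
    # In tiled layout only the last two dims are padded up to the tile size;
    # every leading dim multiplies the number of 32x32 faces.
    *batch, h, w = shape
    faces = math.ceil(h / TILE) * math.ceil(w / TILE)
    return math.prod(batch) * faces

print(tile_count([1, 1, 2, 256]))  # 1*1 * ceil(2/32)*ceil(256/32) = 8 tiles
print(tile_count([1, 2, 1, 256]))  # 1*2 * ceil(1/32)*ceil(256/32) = 16 tiles
```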
Thanks. I'll close the issue once reshape_on_device is deprecated. Hope we can have better docs in the future.
Describe the bug
Bug discovered during bring-up of Gemma 2B on my GGML backend: Gemma triggers a crash during reshape. With some debug prints I was able to narrow down the issue. Gemma asks to reshape a tensor from [1, 1, 2, 256] to [1, 2, 1, 256]. This is handled by ttnn.reshape_on_device, and it should work according to the docstring, which indicates that the function always works if the last dimension of the source and destination tensors is the same. https://github.com/tenstorrent/tt-metal/blob/b75f637ffc87422a609c529a08881908c92de26b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape_pybind.cpp#L56
The issue can be replicated with a simple TTNN program.
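The original script is not shown here; a minimal sketch of what such a repro might look like, assuming the standard ttnn Python host API (ttnn.open_device, ttnn.from_torch, ttnn.to_torch) and the exact reshape_on_device call signature, is:

```python
# Hypothetical repro sketch (not the original script from this report).
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Build a [1, 1, 2, 256] bfloat16 tensor in tile layout on device.
torch_input = torch.rand((1, 1, 2, 256), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(
    torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
)

# Reshape to [1, 2, 1, 256]; the last dimension is unchanged, so per the
# docstring this should be supported, but it triggers the crash instead.
tt_output = ttnn.reshape_on_device(tt_input, (1, 2, 1, 256))
print(ttnn.to_torch(tt_output).shape)

ttnn.close_device(device)
```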
with output:
To Reproduce
Steps to reproduce the behavior:
Expected behavior
reshape_on_device should have reshaped the tensor correctly.
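For reference, the same reshape on the host with plain PyTorch is trivially valid (the last dimension is unchanged), which is the behavior the device op is expected to mirror:

```python
# Host-side reference semantics for the same reshape.
import torch

x = torch.rand((1, 1, 2, 256))
y = x.reshape(1, 2, 1, 256)
assert y.shape == (1, 2, 1, 256)
```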
Additional context
If relevant, this is the debug print and the log in GGML that led to the discovery of the bug.
Log: