tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

[Bug Report] Reshape being used in an incorrect manner #15317


jvegaTT commented 1 day ago

Describe the bug

In Torch's implementation of reshape, the condition input_shape.volume() == output_shape.volume() must hold. However, some OPs such as slice and embedding have been using reshape as a wrapper to call tensor.reshape(), specifically when dealing with Tile layouts, and as a result call it where the above condition is not met (which can fold padding data into the "real" data, which is their intention).
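As a reference point, here is a minimal self-contained sketch of that volume invariant (plain std::vector stands in for the actual shape types; the names here are illustrative, not the tt-metal API):

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Sketch only: Torch-style reshape semantics require the logical (unpadded)
// volumes of the input and output shapes to match.
uint64_t volume(const std::vector<uint32_t>& shape) {
    return std::accumulate(shape.begin(), shape.end(), uint64_t{1},
                           std::multiplies<uint64_t>{});
}

int main() {
    std::vector<uint32_t> input_shape  = {1, 32};    // logical dims of Shape{1[32],32[32]}
    std::vector<uint32_t> output_shape = {2, 1, 16}; // logical dims of Shape{2,1[32],16[32]}
    assert(volume(input_shape) == volume(output_shape)); // 32 == 32: reshape is legal
    return 0;
}
```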

The current implementation has the following lines to catch those cases:

```cpp
bool tile_tensor_view_reshape_possible =
    (layout == ttnn::Layout::TILE and
     ((shape.with_tile_padding()[-2] % tile_second_dim == 0) and
      (shape.with_tile_padding()[-1] % tile_second_dim == 0)) and
     (tensor_shape.with_tile_padding()[-1] == shape.with_tile_padding()[-1]));
if (!(ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) or tile_tensor_view_reshape_possible) {
    return tensor.reshape(shape);
}
```

However, this can introduce other issues. For instance, the reshape:

`Shape{1[32],32[32]} -> Shape{2,1[32],16[32]}`

is perfectly valid in terms of volume, but it is not a 0-cost view: we are going from a shape needing 1 tile with 32 valid values in that tile to a shape needing 2 tiles with 16 valid values in each tile. The check above, however, would incorrectly flag this as a 0-cost view.
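To make the cost concrete, here is a small self-contained sketch of the tile arithmetic (assuming 32x32 tiles; the helper is illustrative, not the tt-metal API):

```cpp
#include <cstdint>
#include <iostream>

constexpr uint32_t TILE_H = 32;
constexpr uint32_t TILE_W = 32;

// Number of tiles needed to cover a (rows x cols) 2D face, rounding up.
uint32_t tiles_for(uint32_t rows, uint32_t cols) {
    return ((rows + TILE_H - 1) / TILE_H) * ((cols + TILE_W - 1) / TILE_W);
}

int main() {
    // Shape{1[32],32[32]}: one 32x32 tile holding 1*32 = 32 valid values.
    uint32_t src_tiles = tiles_for(1, 32);      // = 1
    // Shape{2,1[32],16[32]}: two 32x32 tiles, each holding 1*16 = 16 valid values.
    uint32_t dst_tiles = 2 * tiles_for(1, 16);  // = 2
    std::cout << src_tiles << " tile(s) -> " << dst_tiles << " tile(s)\n";
    // Same logical volume (32), but a different tile count and a different
    // placement of valid values inside the tiles, so a 0-cost view is impossible.
    return 0;
}
```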

The short-term solution is to check whether the volumes match and to only call the above path when they do not (a sketch of this gating follows the example below). However, this means we are still using reshape as a hack (other OPs should call tensor.reshape themselves when a view is valid; the reshape OP should only be used for genuine reshapes, not hacked around), and even then there are possible bugs, like:

`Shape{1[32],31[32]} -> Shape{2,1[32],16[32]}`

where the volumes don't match and a view is also not possible; this will lead to a TT_ASSERT failure in tensor.reshape. In this case the user should have called tensor.reshape themselves to go from `Shape{1[32],31[32]}` to `Shape{1,1[32],32[32]}`, and then used reshape to go to `Shape{2,1[32],16[32]}` in a valid configuration.
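A sketch of that short-term gating, reusing the variables from the snippet above (this illustrates the proposal, not a committed fix; the host-tensor branch is left out):

```cpp
// Only take the view shortcut when the volumes do NOT match, i.e. when a caller
// is deliberately folding padding into the data; matching-volume cases should
// go through the full reshape path, and mismatches that are not viewable will
// still hit the TT_ASSERT in tensor.reshape.
bool volumes_match = tensor_shape.volume() == shape.volume();
if (not volumes_match and tile_tensor_view_reshape_possible) {
    return tensor.reshape(shape);
}
```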

The real solution is for the individual users to stop using reshape as a hack, so that we can simply have a TT_ASSERT at reshape invocation that errors if the volumes don't match. Alternatively, an OP similar to reshape could be created to purposely incorporate padding data into the real data, instead of using reshape as a hack.
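For illustration, that long-term guard could be as simple as the following at the top of the reshape OP (a sketch under the assumption that no caller relies on the hack anymore; variable names follow the snippet above):

```cpp
// Once no OP abuses reshape to fold padding into data, reshape can hard-fail on
// any volume mismatch; the padding-folding use case would move to a dedicated OP.
TT_ASSERT(
    tensor_shape.volume() == shape.volume(),
    "reshape requires the input and output logical volumes to match");
```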