tenstorrent / pytorch2.0_ttnn

⭐️ TTNN Compiler for PyTorch 2.0 ⭐️ It enables running PyTorch 2.0 models on Tenstorrent hardware.
https://tenstorrent.github.io/tt-metal/latest/ttnn/

bert model view to ttnn.reshape issues #27

Open kevinwuTT opened 2 months ago

kevinwuTT commented 2 months ago

bert model calls aten.view with these shapes:

input_shape -> output_shape:

essentially the same as ttnn.squeeze with dim = 0, compatible with ttnn.reshape:
[1, 16, 256, 256] -> [16, 256, 256]
[1, 16, 256, 64] -> [16, 256, 64]
[1, 16, 64, 256] -> [16, 64, 256]
[1, 256, 4096] -> [256, 4096]
[1, 256, 1024] -> [256, 1024]

essentially the same as ttnn.unsqueeze_to_4D, compatible with ttnn.reshape:
[16, 256, 256] -> [1, 16, 256, 256]
[16, 256, 64] -> [1, 16, 256, 64]

unsqueeze, compatible with ttnn.reshape:
[256, 2] -> [1, 256, 2]
[256, 1024] -> [1, 256, 1024]
[256, 4096] -> [1, 256, 4096]

Not compatible with ttnn.reshape:
[1, 256, 1024] -> [1, 256, 16, 64]
[1, 256, 16, 64] -> [1, 256, 1024]
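
For illustration, a rough sketch of how these four groups could be classified when lowering aten.view. This is not the actual pass in this repo; the helper name and return labels are hypothetical:

```python
# Hypothetical classification helper, illustrating the four groups above.
# Not the actual lowering pass in pytorch2.0_ttnn.
def classify_view(input_shape, output_shape):
    inp, out = list(input_shape), list(output_shape)
    # Dropping a leading 1 is essentially ttnn.squeeze with dim = 0.
    if len(inp) == len(out) + 1 and inp[0] == 1 and inp[1:] == out:
        return "squeeze_dim0"
    # Prepending 1s is essentially ttnn.unsqueeze_to_4D / unsqueeze.
    if len(out) > len(inp) and out[len(out) - len(inp):] == inp and all(
        d == 1 for d in out[: len(out) - len(inp)]
    ):
        return "unsqueeze"
    # Changing the last dim (e.g. [1, 256, 1024] <-> [1, 256, 16, 64]) is the
    # case that a plain ttnn.reshape cannot handle here.
    if inp[-1] != out[-1]:
        return "needs_fallback"
    return "reshape"
```

With the shapes listed above, only the last group returns "needs_fallback"; the others map to "squeeze_dim0" and "unsqueeze".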

ayerofieiev-tt commented 1 month ago

@mcw-zwakeelTT can you please add an unsqueeze example?

ayerofieiev-tt commented 1 month ago

No known workarounds. Might be fixed around end of July.

mcw-zwakeelTT commented 1 month ago

@ayerofieiev-tt Unsqueezing the last dim of a 3D tensor to make the output 4D throws a runtime error "Unable to reshape given tensor". For example, (5, 2, 4) with dim = 3 should produce (5, 2, 4, 1) but fails. There are no such examples in the bert test case, though.
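
A minimal repro sketch of that case, assuming the usual ttnn host APIs (ttnn.open_device, ttnn.from_torch, ttnn.reshape); the exact call sequence is an assumption, not code from the bert test case:

```python
import torch
import ttnn

# Repro sketch for the case described above: unsqueezing the last dim of a
# 3D tensor, (5, 2, 4) -> (5, 2, 4, 1). API usage here is assumed; this shape
# does not appear in the bert test case.
device = ttnn.open_device(device_id=0)

torch_input = torch.rand(5, 2, 4, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Expected (5, 2, 4, 1); instead raises "Unable to reshape given tensor".
tt_output = ttnn.reshape(tt_input, (5, 2, 4, 1))

ttnn.close_device(device)
```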

ayerofieiev-tt commented 1 month ago

I wonder if this works

input = T[1, 256, 1024]
shape = [1, 256, 16, 64]
reshape(input, shape)

lower to

if(input is tile and output shape is not tiled)
  output = to_layout(input, RowMajor);  // [1, 256, 1024]
  output = reshape(output, shape); // [1, 256, 16, 64]
  output = to_layout(output, Tiled); // [1, 256, 16[32], 64]
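
A rough Python rendering of that lowering, assuming ttnn.to_layout / ttnn.reshape behave as named and adding a tile-alignment check; just a sketch of the idea, not the actual conversion pass:

```python
import ttnn

TILE = 32  # tile height/width used by TILE_LAYOUT

# Sketch of the lowering above; the function name and the tile-alignment
# check are assumptions, not the repo's actual implementation.
def reshape_via_row_major(tt_input, shape):
    needs_fallback = (
        tt_input.layout == ttnn.TILE_LAYOUT
        and (shape[-2] % TILE != 0 or shape[-1] % TILE != 0)
    )
    if needs_fallback:
        out = ttnn.to_layout(tt_input, ttnn.ROW_MAJOR_LAYOUT)  # [1, 256, 1024]
        out = ttnn.reshape(out, shape)                         # [1, 256, 16, 64]
        return ttnn.to_layout(out, ttnn.TILE_LAYOUT)           # height padded 16 -> 32
    return ttnn.reshape(tt_input, shape)
```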

Are ops in general (like softmax) aware of padding?

mcw-zwakeelTT commented 1 month ago

reshape to 4D is executed on device and only works if the last dim is the same for both input & output. In your example it will throw a runtime error (1024 != 64). Currently, reshape takes place after converting the input tensor to row-major layout. Tile layout is only valid if the tensor is reshaped to a tile-aligned height & width. The inputs of ops are padded if they are requested to be in tile layout.
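
To make those constraints concrete, a small hypothetical check based on this description (32 is the tile height/width):

```python
TILE = 32  # tile height and width

def can_reshape_on_device(input_shape, output_shape):
    # Per the comment above, the on-device 4D reshape only works when the last
    # dim is unchanged; [1, 256, 1024] -> [1, 256, 16, 64] fails (1024 != 64).
    return input_shape[-1] == output_shape[-1]

def stays_tiled(output_shape):
    # Tile layout remains valid only if the reshaped height and width are
    # multiples of the tile size; otherwise fall back to row-major, with
    # padding reapplied when an op requests tile layout inputs.
    return output_shape[-2] % TILE == 0 and output_shape[-1] % TILE == 0
```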