tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

TTNN Matmul op doesn't broadcast batch dimension if the 3rd dim isn't 1 #12833

Closed. sdjordjevicTT closed this issue 1 week ago.

sdjordjevicTT commented 2 weeks ago

Describe the bug
TTNN Matmul op does not work when the 3rd dimension of input_tensor_b isn't 1, so the broadcast version of Matmul isn't executed. According to the public documentation, this kind of product is supported. I consistently hit an error when trying to execute this specific operation.

To Reproduce

import ttnn
import torch

device = ttnn.open_device(device_id=0)

# Rank-4 first input and rank-3 second input.
tensor_a = torch.rand(7, 7, 128, 2048)
tensor_b = torch.rand(7, 2048, 256)

# torch broadcasts the batch dimensions; this prints torch.Size([7, 7, 128, 256]).
tensor_c = torch.matmul(tensor_a, tensor_b)

print(tensor_c.shape)

a_tt = ttnn.from_torch(tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
b_tt = ttnn.from_torch(tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

# Fails: ttnn.matmul does not broadcast the batch dimension in this case.
c_tt = ttnn.matmul(a_tt, b_tt)

ttnn.close_device(device)

Error when the above code is executed:

Always | FATAL    | Error
Traceback (most recent call last):
  File "/localdev/sdjordjevic/src/tt-metal/simple_test.py", line 16, in <module>
    c_tt = ttnn.matmul(a_tt, b_tt)
  File "/localdev/sdjordjevic/src/tt-metal/ttnn/ttnn/decorators.py", line 326, in __call__
    return self.function(*function_args, **function_kwargs)
RuntimeError: TT_FATAL @ ../ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp:971: a_shape.rank() == b_shape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank"
info:
Error
backtrace:
 --- /localdev/sdjordjevic/src/tt-metal/ttnn/ttnn/_ttnn.so(+0x36f8a8) [0x7fd87c4d78a8]
 --- ttnn::operations::matmul::Matmul::validate(std::__1::vector<tt::tt_metal::Tensor, std::__1::allocator<tt::tt_metal::Tensor>> const&, std::__1::vector<std::__1::optional<tt::tt_metal::Tensor const>, std::__1::allocator<std::__1::optional<tt::tt_metal::Tensor const>>> const&) const

Expected behavior
As stated in the public docs, in this particular case the b input should be broadcast from (7, 2048, 256) to (7, 7, 2048, 256) and the matrix product executed, producing a (7, 7, 128, 256) result.
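
Until the broadcast path is supported, one possible host-side workaround (an untested sketch on my part, not an officially supported pattern) is to materialize the broadcast in torch before converting to TT-NN, so that ttnn.matmul receives two rank-4 inputs with matching batch dimensions:

import ttnn
import torch

device = ttnn.open_device(device_id=0)

tensor_a = torch.rand(7, 7, 128, 2048)
tensor_b = torch.rand(7, 2048, 256)

# Materialize the batch broadcast on the host: (7, 2048, 256) -> (7, 7, 2048, 256).
tensor_b_full = torch.broadcast_to(tensor_b, (7, 7, 2048, 256)).contiguous()

a_tt = ttnn.from_torch(tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
b_tt = ttnn.from_torch(tensor_b_full, layout=ttnn.TILE_LAYOUT, device=device)

# Both inputs are now rank 4 with identical batch dims, so this is a plain bmm.
c_tt = ttnn.matmul(a_tt, b_tt)

ttnn.close_device(device)

This trades extra host memory and host-to-device transfer for avoiding the on-device broadcast.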

Screenshots: N/A

Environment information: not provided

Additional context: N/A

bbradelTT commented 2 weeks ago

We put a check in that requires the second input to have an explicit first dimension of size 1 in this scenario.

@TT-BrianLiu do you know how much work it'd be to allow the code to work in this scenario?

If it's too much work, we would need to update the documentation.
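
For illustration only, here is a rough pure-Python restatement of the failing condition quoted in the traceback above (the function name is hypothetical; this is not the actual C++ validation in matmul_op.cpp):

def check_matmul_ranks(a_shape, b_shape):
    # Restates the TT_FATAL condition from the traceback:
    # the non-bcast (bmm) path expects both inputs to have the same rank.
    if len(a_shape) != len(b_shape):
        raise RuntimeError("bmm (non-bcast matmul) expects input tensors of the same rank")


try:
    # Shapes from the repro: rank 4 vs rank 3, so the check fires.
    check_matmul_ranks((7, 7, 128, 2048), (7, 2048, 256))
except RuntimeError as e:
    print(e)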

bbradelTT commented 1 week ago

The batch dimensions don't line up regardless. I updated the doc string to cover this scenario.

Doc string updated via PR https://github.com/tenstorrent/tt-metal/pull/13071