tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

TTNN Matmul op doesn't broadcast batch dim of input a #12834

Closed sdjordjevicTT closed 1 week ago

sdjordjevicTT commented 2 weeks ago

Describe the bug TTNN Matmul op does not work when the batch dim of input_tensor_a needs to be broadcast. According to the public documentation, this kind of product is supported, yet I consistently hit an error when executing this specific operation.

To Reproduce Steps to reproduce the behavior:

import ttnn
import torch

device = ttnn.open_device(device_id=0)

tensor_a = torch.rand(1, 128, 2048)  # batch dim 1
tensor_b = torch.rand(7, 2048, 256)  # batch dim 7

# torch broadcasts a's batch dim; the result shape is (7, 128, 256)
tensor_c = torch.matmul(tensor_a, tensor_b)

print(tensor_c.shape)

a_tt = ttnn.from_torch(tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
b_tt = ttnn.from_torch(tensor_b, layout=ttnn.TILE_LAYOUT, device=device)

# fails: ttnn.matmul does not broadcast the batch dim of input a
c_tt = ttnn.matmul(a_tt, b_tt)

ttnn.close_device(device)

Error once the above code is executed:

Always | FATAL    | bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent
Traceback (most recent call last):
  File "/localdev/sdjordjevic/src/tt-metal/simple_test.py", line 16, in <module>
    c_tt = ttnn.matmul(a_tt, b_tt)
  File "/localdev/sdjordjevic/src/tt-metal/ttnn/ttnn/decorators.py", line 326, in __call__
    return self.function(*function_args, **function_kwargs)
RuntimeError: TT_FATAL @ ../ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp:975: a_shape[i] == b_shape[i]
info:
bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent
backtrace:
 --- /localdev/sdjordjevic/src/tt-metal/ttnn/ttnn/_ttnn.so(+0x36f8a8) [0x7f086a4d88a8]
 --- ttnn::operations::matmul::Matmul::validate(std::__1::vector<tt::tt_metal::Tensor, std::__1::allocator<tt::tt_metal::Tensor>> const&, std::__1::vector<std::__1::optional<tt::tt_metal::Tensor const>, std::__1::allocator<std::__1::optional<tt::tt_metal::Tensor const>>> const&) const

Expected behavior As stated in the public docs, in this particular case input_tensor_a should be broadcast from (1, 128, 2048) to (7, 128, 2048) and the matrix product executed.
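Until the op supports this, one host-side workaround is to materialize the batch broadcast in torch before converting to a TTNN tensor. This is a sketch assuming the repro's shapes; the ttnn calls are the same as in the repro and are shown as comments since they require a device:

```python
import torch

tensor_a = torch.rand(1, 128, 2048)
tensor_b = torch.rand(7, 2048, 256)

# Materialize the batch broadcast on the host: (1, 128, 2048) -> (7, 128, 2048).
# expand() creates a zero-copy view; contiguous() makes a real copy so the
# converted tensor holds actual data for every batch.
a_expanded = tensor_a.expand(tensor_b.shape[0], -1, -1).contiguous()

# Both inputs now share the batch dim, the case ttnn.matmul accepts:
# a_tt = ttnn.from_torch(a_expanded, layout=ttnn.TILE_LAYOUT, device=device)
# b_tt = ttnn.from_torch(tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
# c_tt = ttnn.matmul(a_tt, b_tt)

print(a_expanded.shape)  # torch.Size([7, 128, 2048])
```

The result is numerically identical to the broadcasting matmul, at the cost of materializing the repeated batches in host memory.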


bbradelTT commented 2 weeks ago

Only input_tensor_b can be broadcast. That leads to some interesting corner cases (which the documentation processing has cut off, although that will need to be fixed separately), as well as to input_tensor_a not being broadcastable.
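The asymmetry described above can be sketched as a pure-Python approximation of the batch-dim check behind the TT_FATAL in the traceback (the function name and structure are illustrative, not the actual TT-Metal code):

```python
def check_batch_dims(a_shape, b_shape):
    """Approximate TTNN matmul's batch-dim validation, per the comment above.

    For every batch dim (all dims except the last two), input_tensor_b may
    be 1 and is broadcast up to a's size, but input_tensor_a may not be.
    """
    assert len(a_shape) == len(b_shape)
    for a_dim, b_dim in zip(a_shape[:-2], b_shape[:-2]):
        if a_dim == b_dim:
            continue  # matching batch dims are always fine
        if b_dim == 1:
            continue  # b's batch dim is broadcast up to a's
        # a_dim == 1 with b_dim > 1 is the unsupported case from this issue
        return False
    return True

# The failing case from the report: a's batch dim would need broadcasting.
print(check_batch_dims((1, 128, 2048), (7, 2048, 256)))  # False
# The supported direction: b's batch dim is broadcast instead.
print(check_batch_dims((7, 128, 2048), (1, 2048, 256)))  # True
```

This mirrors why the repro trips `a_shape[i] == b_shape[i]` in matmul_op.cpp while the mirrored shapes would pass.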

I will need to update the documentation.

bbradelTT commented 2 weeks ago

Re: "which the documentation processing has cut off, although that will need to be fixed separately" -- it turns out that https://docs.tenstorrent.com/ttnn/latest/ttnn/ttnn/matmul.html is badly out of date and does not reflect the current doc strings.

bbradelTT commented 1 week ago

Updated the doc string to cover this scenario via PR https://github.com/tenstorrent/tt-metal/pull/13071