tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

[Bug Report] Ttnn::multiply does not broadcast in second dimension #13646

Open dmakoviichuk-tt opened 1 week ago

dmakoviichuk-tt commented 1 week ago

Describe the bug
ttnn.multiply does not broadcast in the second dimension. It fails silently, producing garbage (near-NaN/Inf) values instead of raising an error.

To Reproduce
Steps to reproduce the behavior:

import torch
import ttnn

with ttnn.manage_device(device_id=0) as device:
    x = torch.ones((16, 16, 32, 32), dtype=torch.bfloat16)
    y = torch.ones((16, 1, 32, 32), dtype=torch.bfloat16) * 0.5

    x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    x_y_mult_tt = ttnn.multiply(x_tt, y_tt)
    print(ttnn.to_torch(x_y_mult_tt))

It will print garbage values.

Expected behavior
ttnn.multiply should broadcast y along the second dimension and print a tensor of 0.5s.

Additional context
Currently this forces us to generate broadcasted tensors on the CPU, which is a huge problem for training performance. @eyonland if you don't think you can fix it, please let me know and I'll find another owner.
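For reference, the broadcasting behavior the report expects is the standard NumPy/PyTorch rule: a size-1 dimension in one operand is stretched to match the other operand. A minimal sketch in NumPy (used here only because it runs without ttnn or device hardware) with the same shapes as the reproduction:

```python
import numpy as np

# Same shapes as in the bug report: x is (16, 16, 32, 32), y is (16, 1, 32, 32).
x = np.ones((16, 16, 32, 32), dtype=np.float32)
y = np.ones((16, 1, 32, 32), dtype=np.float32) * 0.5

# Standard broadcasting: y's size-1 second dimension is stretched to 16,
# so every element of the result should be 1.0 * 0.5 = 0.5.
out = x * y
assert out.shape == (16, 16, 32, 32)
assert np.all(out == 0.5)
```

This is the result ttnn.multiply is expected to produce on device; the reported bug is that it instead returns garbage without raising an error.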

eyonland commented 1 week ago

Are you able to use ttnn.bcast here instead as a workaround?

umadevimcw commented 1 week ago

@dmakoviichuk-tt @eyonland https://github.com/tenstorrent/tt-metal/pull/13673

The above PR adds functional support for broadcasting the second dimension using repeat. This will affect performance since it uses the repeat op, but it can be used now for correctness and updated later once native broadcast support lands.