tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

[Bug Report] Multiply and divide have different broadcasting rules #12798

Open rfurko-tt opened 2 weeks ago

rfurko-tt commented 2 weeks ago

Describe the bug I can't use ttnn.divide the same way as ttnn.multiply. Multiply works as expected; divide crashes.

To Reproduce

import ttnn
import torch
import numpy as np

with ttnn.manage_device(device_id=0) as device:
    x = torch.ones((1, 2, 3, 4), dtype=torch.bfloat16)
    y = torch.ones((1, 1, 1, 1), dtype=torch.bfloat16) * 0.5

    x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    x_y_mult_tt = ttnn.multiply(x_tt, y_tt)
    print(ttnn.to_torch(x_y_mult_tt))

    x_y_divide_tt = ttnn.divide(x_tt, y_tt)
    print(ttnn.to_torch(x_y_divide_tt))

Expected behavior Divide should have the same broadcasting rules as multiply.
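For reference, the same shapes broadcast identically for both ops in plain torch, which is the golden behavior being asked for here (a host-side check only, not ttnn):

import torch

x = torch.ones((1, 2, 3, 4), dtype=torch.bfloat16)
y = torch.ones((1, 1, 1, 1), dtype=torch.bfloat16) * 0.5

# torch broadcasts the (1, 1, 1, 1) tensor across (1, 2, 3, 4) for both ops.
print(x * y)  # all elements 0.5, shape (1, 2, 3, 4)
print(x / y)  # all elements 2.0, shape (1, 2, 3, 4)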

Please complete the following environment information:

dmakoviichuk-tt commented 2 weeks ago

@eyonland I think you are owner of the unary/binary ops. Feel free to move to the right person.

dmakoviichuk-tt commented 2 weeks ago

Hi @yugaoTT, I've noticed that DIV is implemented internally as RECIP and MUL ops. It looks like the broadcast case was accidentally missed, and it seems easy to support since everything is already in place.

umadevimcw commented 1 week ago

@dmakoviichuk-tt @eyonland @rfurko-tt Broadcast is not supported for Div OP

[screenshot: bcast validation handles only ADD, SUB, and MUL]

Although division (Div) is a combination of Recip and MUL, broadcasting was not occurring in the program factory during validation because the Div operation was treated as a separate operation. As shown in the image above, the bcast only checks for ADD, SUB, and MUL operations.

To enable broadcasting, I introduced DIV_FAST (as shown in the image below), which did enable broadcasting. However, this altered the operation's definition, and the division itself no longer produces the intended result.

[screenshot: DIV_FAST added to the bcast path]

Repeat op Approach:

Finally, I combined the ones_like and multiply operations to achieve the desired broadcast. This approach produced the correct output for the test mentioned above (see the changes below for the differences).

Please find the PR here: #12932

[screenshot: diff from PR #12932]
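For intuition, at the Python level the repeat-op approach amounts to roughly the following (a sketch only, assuming ttnn.ones_like is available and reusing x_tt/y_tt from the repro above; the actual PR applies this inside the op rather than in user code):

y_repeated_tt = ttnn.multiply(ttnn.ones_like(x_tt), y_tt)  # broadcast y up to x's shape via multiply
x_y_divide_tt = ttnn.divide(x_tt, y_repeated_tt)           # inputs now have matching shapes
print(ttnn.to_torch(x_y_divide_tt))                        # expected: all 2.0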

If the changes are acceptable, I can generalize it for both inputs.

eyonland commented 1 week ago

@umadevimcw, it looks like your solution would generalize for both inputs automatically and possibly solve this as a general case on eltwise ops. I think the return on this is high, and we would quickly get better coverage of this functionality.

dmakoviichuk-tt commented 1 week ago

@umadevimcw These changes are not acceptable; this should be solved in the proper way. To run one simple DIV op we are now adding several extra ops. @eyonland It is a really bad solution for a couple of reasons:

  1. We are adding op-specific changes to the generic binary function.
  2. We are making division slower. If we wanted it slower we could always call reciprocal and then multiplication, which is the workaround we already have today.
  3. DIV is currently implemented as a chain of RECIP + MUL ops inside one kernel, but it looks like bcast doesn't support this chaining; it supports only MUL.

So the right solution is to make sure bcast can support chaining, or to change the bcast kernel and add bcast_div* functions that do the chaining.

So if you take a look here: https://github.com/tenstorrent/tt-metal/blob/8e0e2e1994a67fa4b179c18b54565685e5eba04a/ttnn/cpp/ttnn/operations/data_movement/bcast/bcast_types.cpp#L13 Imagine you added BCastMathOp::DIV: you would need to add a few new kernel functions, similar to "mul_tiles_bcast", that do the same chaining as the regular binary op.

rfurko-tt commented 1 week ago

@umadevimcw Thanks for providing detailed explanation. Could you please provide more information:

  1. Why is ttnn::divide executed as two operations? We often work in bfloat16, which already suffers from limited precision, and I would expect every extra operation to take a little more precision away from the result (see the snippet after this list).
  2. Why don't we re-use the multiply broadcasting mechanism inside DIV_FAST? (In the proposed solution we add a redundant multiply just to get it implicitly.)
  3. I think it's fair to have a bunch of workarounds in the model (temporarily), but we shouldn't have hacks on top of hacks in op development. What's our current position on this?
  4. From a performance perspective, a single divide now performs 3 operations: a multiply plus the multiply and reciprocal that make up the divide itself. I think this is unreasonable and will be a hidden source of development slowdown. What's our policy on adding this kind of overhead?
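To illustrate point 1 with a small host-side torch example (a sketch only; ttnn kernels may round differently): splitting a divide into reciprocal plus multiply rounds twice in bfloat16, so some elements can land a ULP away from a single fused divide.

import torch

a = torch.tensor([7.0, 11.0, 13.0], dtype=torch.bfloat16)
b = torch.tensor([3.0, 3.0, 3.0], dtype=torch.bfloat16)

direct = a / b                      # one rounding of the quotient
chained = a * torch.reciprocal(b)   # rounds 1/b first, then rounds the product

# Double rounding means the two results can differ in the last bit for some inputs.
print(direct)
print(chained)
print(torch.equal(direct, chained))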

Thanks in advance!

eyonland commented 1 week ago

@rfurko-tt and @dmakoviichuk-tt, if I understand this problem correctly, we cannot do broadcasting within the kernel operation itself because there is no clear way to identify the logical shape of the tiled tensor in the width and height dimensions. We either allow this workaround to unblock models, or we wait for the tensor layout and shape class to be rewritten and then come back and fix the kernels to handle broadcasting properly. My position is that we allow this for now, knowing it is not an optimal solution, and then circle back and implement it properly in the kernel itself once the tensor layout is fixed. @umadevimcw, please open a separate issue to track the work of implementing broadcasting in the kernel itself. This needs to be done regardless of the outcome of this issue.

dmakoviichuk-tt commented 6 days ago

@eyonland You have understood the problem incorrectly. If broadcasting works in multiply it must work in the div operation; there are no excuses. In my opinion this hotfix doesn't make sense, because it is even slower than the obvious workaround of 2 manual calls. We should never allow "fixes" like this. To unblock the issue we can simply call reciprocal and multiply manually. This is a bug in the binary op implementation: in one case div is treated as a chain of RECIP and MUL, but when it goes down the broadcast path there is no support for that chaining on either the TTNN or the kernel side. I expected you, as the owner, to drive this issue because it will require proper fixes in the TTNN binary op.
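As an interim unblock in model code, the manual workaround mentioned above would look roughly like this (a sketch, assuming ttnn.reciprocal behaves like the other eltwise unary ops, and reusing x_tt/y_tt from the repro above):

y_recip_tt = ttnn.reciprocal(y_tt)               # unary op, no broadcasting needed
x_y_divide_tt = ttnn.multiply(x_tt, y_recip_tt)  # multiply already broadcasts correctly
print(ttnn.to_torch(x_y_divide_tt))              # expected: all 2.0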

mrakitaTT commented 7 hours ago

Hi folks, we are hitting this issue in the tt-mlir compiler, so I just wanted to check what the status of this issue is and whether you have an estimate of how long it will take to fix.