tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

All-Gather sometimes produces incorrect output for dim=0/1 cases #6448

Open SeanNijjar opened 6 months ago

SeanNijjar commented 6 months ago

At the moment, tests pass because all-gather falls back to unidirectional for dim = 0 or dim = 1. To re-enable the failing tests, enable bidirectional all-gather for all dims in all_gather_op.hpp. All cases should produce correct output with bidirectional enabled.

Note that this issue is still sometimes present. In all_gather_op.hpp, we conditionally disable bidirectional support when any of the dims outside of the concat dim is > 1. We need to lift this restriction.
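For readers following along, the restriction described above can be sketched as a small predicate. This is only an illustration in Python, not the actual host code: the function names are hypothetical, and it assumes "dims outside of the concat dim" means the dims that come before it in the shape.

```python
from math import prod

def outer_volume(shape, concat_dim):
    # Product of all dims that come before the concat dim (1 if there are none).
    return prod(shape[:concat_dim])

def bidirectional_allowed(shape, concat_dim):
    # Hypothetical stand-in for the restriction: bidirectional all-gather is
    # only kept enabled when every dim outside the concat dim has size 1.
    return outer_volume(shape, concat_dim) == 1

print(bidirectional_allowed([1, 1, 256, 384], 2))  # outer dims are all 1
print(bidirectional_allowed([8, 4, 256, 384], 2))  # outer dims are 8 and 4
```

Lifting the restriction would amount to this check no longer being needed for correctness.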

SeanNijjar commented 6 months ago

Marking P2 as this config currently isn't used (or expected to be used) in any of our priority models

SeanNijjar commented 5 months ago

I'm going to bump this up to P1 since it's affecting Mixtral perf. Although it's possible for the Mixtral demo to be implemented with:

ag_out = all_gather(dim=2)
r_out = reshape([1,1,z*y,x] -> [1,z,y,x], ag_out)
reduce(dim=1, r_out)

it's not the ideal design. It also depends on the reshape doing the right thing and being implemented as a reinterpret in this case, which it might not be.

Ideally, we just have all_gather(dim=1) followed by reduce(): fewer operations, fewer opportunities to mess up, cleaner demo code, etc.
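To make the equivalence of the two formulations concrete, here is a minimal host-side sketch in plain Python. It models each device's shard as a flat row-major buffer of shape [1,1,y,x]; the helper names (`concat`, `reduce_sum_dim1`) and the toy sizes are assumptions for illustration, not tt-metal APIs, and nothing here models tiled layouts.

```python
from math import prod

def concat(shards, shard_shape, dim):
    # Concatenate equally shaped row-major flat buffers along `dim`.
    outer = prod(shard_shape[:dim])   # number of slices outside the concat dim
    inner = prod(shard_shape[dim:])   # contiguous elements per slice in one shard
    out = []
    for o in range(outer):
        for s in shards:
            out.extend(s[o * inner:(o + 1) * inner])
    out_shape = list(shard_shape)
    out_shape[dim] *= len(shards)
    return out, out_shape

def reduce_sum_dim1(buf, shape):
    # Sum a row-major [n, z, y, x] buffer over dim 1, giving n * y * x values.
    n, z, y, x = shape
    plane = y * x
    out = []
    for b in range(n):
        base = b * z * plane
        for i in range(plane):
            out.append(sum(buf[base + k * plane + i] for k in range(z)))
    return out

z, y, x = 4, 2, 3  # toy sizes: z devices, each holding a [1, 1, y, x] shard
shards = [[100 * d + i for i in range(y * x)] for d in range(z)]

# Workaround: gather on dim=2, reinterpret as [1, z, y, x], reduce on dim=1.
ag_out, ag_shape = concat(shards, [1, 1, y, x], dim=2)
workaround = reduce_sum_dim1(ag_out, [1, z, y, x])

# Direct form: gather on dim=1, reduce on dim=1.
ag1_out, ag1_shape = concat(shards, [1, 1, y, x], dim=1)
direct = reduce_sum_dim1(ag1_out, ag1_shape)

assert workaround == direct
```

The two paths agree here only because the reshape is a pure reinterpretation of the same row-major buffer. If the reshape has to move data, for instance under a different memory layout, the workaround stops being free.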

SeanNijjar commented 5 months ago

Took a quick look at this since I found and fixed a bug in the all-gather host code while bringing up linear all-gather. It looks like some of the mismatches were resolved but some still remain. I think there are (or were, with the fix on the branch) multiple issues that these tests exposed.

Based on some limited data and a mismatch dump, I suspect that this bug isn't related to dim 0/1 specifically but may be a general problem with all-gathers on a non-width dim where the tensor size is > 1 on any dim outer to the all-gather dim. So far, it also seems like the remaining mismatch might be specific to the tile interleaved layout.

For example, if I have a canonical (output) shape of [8,4,256,384] and I do an all-gather on dim=2 or dim=3, then I think this issue may arise.
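As a rough illustration of why outer dims > 1 change the picture for this shape, the sketch below counts how many separate contiguous chunks each device contributes to a row-major output. The device count of 8 is an assumption (the thread does not state it), `gather_layout` is a hypothetical helper, and element counts are used rather than tiles.

```python
from math import prod

def gather_layout(out_shape, dim, num_devices):
    # Returns (outer_slices, chunk_elems): each device's data lands in
    # `outer_slices` separate contiguous chunks of `chunk_elems` elements
    # in a row-major output buffer.
    assert out_shape[dim] % num_devices == 0
    outer_slices = prod(out_shape[:dim])
    chunk_elems = (out_shape[dim] // num_devices) * prod(out_shape[dim + 1:])
    return outer_slices, chunk_elems

shape = [8, 4, 256, 384]
print(gather_layout(shape, 2, 8))              # gather on dim=2
print(gather_layout(shape, 3, 8))              # gather on dim=3
print(gather_layout([1, 1, 256, 384], 2, 8))   # outer dims all 1
```

With all outer dims equal to 1, each device writes a single contiguous chunk. With outer dims of 8 and 4, the same gather turns into many strided writes per device, which is the case suspected above.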