Open SeanNijjar opened 6 months ago
Marking P2 as this config currently isn't used (or expected to be used) in any of our priority models
I'm going to bump this up to P1 since it's affecting Mixtral perf. Although it's possible for the mixtral demo to be implemented with
ag_out = all_gather(dim=2)
r_out = reshape([1,1,z*y,x] -> [1,z,y,x], ag_out)
reduce(dim=1, r_out)
it's not the ideal design, and it also depends on the reshape doing the right thing and being implemented as a reinterpret in this case, which it might not be.
Ideally, we just have all_gather(dim=1), reduce(): fewer operations, fewer opportunities to mess up, cleaner demo code, etc.
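To make the equivalence between the two formulations concrete, here is a minimal NumPy sketch. It assumes row-major host tensors, models all_gather as a plain concatenation of per-device shards, and assumes the reduce is a sum; the shard sizes and variable names are illustrative and not taken from the demo code.

```python
import numpy as np

# Illustrative sizes: z devices, each holding a [1, 1, y, x] shard.
z, y, x = 4, 3, 5
rng = np.random.default_rng(0)
shards = [rng.standard_normal((1, 1, y, x)) for _ in range(z)]

# Workaround: gather on dim=2, reinterpret as [1, z, y, x], reduce on dim=1.
ag_out = np.concatenate(shards, axis=2)        # shape [1, 1, z*y, x]
r_out = ag_out.reshape(1, z, y, x)             # only valid as a pure reinterpret
workaround = r_out.sum(axis=1)                 # reduce assumed to be a sum

# Preferred: gather directly on dim=1, then reduce.
gathered = np.concatenate(shards, axis=1)      # shape [1, z, y, x]
direct = gathered.sum(axis=1)

same_layout = np.array_equal(r_out, gathered)
same_result = np.allclose(workaround, direct)
```

The two paths agree here only because both outer dims are 1, so the gathered data is already laid out as [1, z, y, x] in row-major order. On device, with tiled layouts, the reshape may not be a free reinterpret, which is the risk described above.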
Took a quick look at this since I found and fixed a bug in the all-gather host code while bringing up linear all-gather. It looks like some of the mismatches were resolved but some still remain. I think there are (or, on the branch, were) multiple issues that exposed themselves in these tests.
Based on some limited data and a mismatch dump, I suspect that this bug isn't related to dim0/1 specifically but may be a general problem with all-gathers on a non-width dim where the tensor size is > 1 on any dim outer to the all-gather dim. So far it also seems that the remaining mismatch might be specific to the tile interleaved layout.
For example, if I have a canonical (output) shape of [8,4,256,384], and I do an all gather on dim=2 or dim=3, then I think this issue may arise.
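As a rough illustration of why outer dims > 1 make this case different, the sketch below counts how many separate contiguous runs the gathered output breaks into. It assumes a plain row-major layout (not the tile interleaved layout mentioned above), and `count_runs` is a hypothetical helper written for this example, not part of the all-gather code.

```python
import numpy as np

def count_runs(per_device_shape, dim, num_devices):
    """Count contiguous runs of same-device data in the flattened output."""
    # Each device contributes a shard filled with its own device id.
    shards = [np.full(per_device_shape, d, dtype=np.int32)
              for d in range(num_devices)]
    out = np.concatenate(shards, axis=dim).ravel()
    # A new run starts wherever the device id changes.
    return 1 + int(np.count_nonzero(out[1:] != out[:-1]))

# Outer dims are all 1: each device writes a single contiguous block.
runs_trivial_outer = count_runs((1, 1, 4, 6), dim=2, num_devices=4)

# Outer dims are 2 and 3: each device's data is split into 2*3 strided blocks.
runs_nontrivial_outer = count_runs((2, 3, 4, 6), dim=2, num_devices=4)
```

With trivial outer dims the output is one block per device; with outer dims of 2 and 3 the same gather produces six blocks per device, so any addressing logic that assumes a single block per device would write to the wrong offsets.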
At the moment, tests pass because all-gather falls back to unidirectional for dim = 0 or dim = 1. To re-enable the failing tests, enable bidirectional all-gather for all dims in all_gather_op.hpp. All cases should be producing correct output with bidirectional enabled.
Note that this issue is still sometimes present. In allgather_op.hpp, we conditionally disable bidirectional support when the dims outside of the concat dim are > 1. We need to lift this restriction.
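The restriction described above can be sketched as a simple predicate. This is a hypothetical Python rendering of the condition as worded in this thread, not the actual code in the header; it assumes "dims outside of the concat dim" means the dims with a lower index than the concat dim, and the function names are made up for illustration.

```python
def has_nontrivial_outer_dims(shape, dim):
    """True if any dim outer to (before) the concat dim has extent > 1."""
    return any(extent > 1 for extent in shape[:dim])

def bidirectional_allowed(shape, dim):
    """Assumed current behavior: bidirectional only when outer dims are all 1."""
    return not has_nontrivial_outer_dims(shape, dim)

# The canonical output shape used as an example earlier in the thread.
shape = [8, 4, 256, 384]
allowed_by_dim = {dim: bidirectional_allowed(shape, dim) for dim in range(4)}
```

Under this reading, only dim=0 of the example shape would run bidirectionally today; lifting the restriction would mean this predicate returns True for every dim once the underlying mismatch is fixed.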