Thank you @YJHMITWEB! The bug was also pointed out by @chengming-zhang. It seems this bug was introduced when we switched from multiple `all_to_all` calls to `all_to_all_single`. @chengming-zhang, can you help us validate that this fix works?
Currently, in the implementation of DeepSpeed sequence parallelism, two `all_to_all` operations are used in the distributed attention to scatter and gather the sequence. However, the `reshape` operation in the second `all_to_all` is wrong. The model will never converge, as the data is corrupted by it.

To easily check the problem with the current implementation, we can do the following to this line:
Remove the attention computation, leaving only the `all_to_all` calls, and simply check `query` before and after the `all_to_all` round trip; the two should be the same.
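For instance, a probe along these lines makes the corruption visible. This is a minimal sketch: `_SeqAllToAll`, `self.spg`, `self.scatter_idx`, and `self.gather_idx` are assumed to match the names in the surrounding Ulysses code, so adjust them if they differ.

```python
# Hedged debugging sketch inside DistributedAttention.forward: skip the
# attention kernel and round-trip `query` through both all_to_alls.
q_gathered = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
q_restored = _SeqAllToAll.apply(self.spg, q_gathered, self.gather_idx, self.scatter_idx)

# A correct implementation returns `query` unchanged on every rank; the
# buggy reshape makes this assertion fail.
assert torch.allclose(query, q_restored), "second all_to_all corrupted the data"
```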
*In the current implementation, `all_to_all` totally messes up the data: the printed values in `query` are misaligned with those in `output`.*

The problem is caused by an incorrect reshape:
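In spirit, the offending pattern collapses the `all_to_all_single` receive buffer straight into the target tensor shape. The snippet below is a paraphrase with illustrative dimension names (`seq_per_rank`, `bsz`, `num_heads`, `head_dim`), not the verbatim line:

```python
# `output` comes back from dist.all_to_all_single with the P received
# chunks concatenated along dim 0, i.e. with an implicit leading rank
# dimension. Flattening it directly into the target shape interleaves
# elements that came from different ranks:
output = output.reshape(seq_per_rank, bsz, num_heads, head_dim)  # WRONG
```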
When performing the second `all_to_all`, the output we gather from the other ranks comes back with the per-rank chunks concatenated along the leading dimension. At this step, we actually want to gather all the heads of the local sequence, so the reshape needs to keep that leading rank dimension explicit instead of flattening across it. Only by doing this can we then transpose the rank dimension next to the head dimension and merge the two, which arranges the data correctly; see the sketch below.
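Concretely, the corrected order of operations looks like this (a sketch under the assumption of a `[seq, batch, heads, head_dim]` layout; `P` is the sequence parallel world size):

```python
# Keep the rank dimension explicit: received chunk p holds this rank's
# sequence block with head block p (num_heads // P heads).
output = output.reshape(P, seq_per_rank, bsz, num_heads // P, head_dim)

# Move the rank dimension next to the head dimension and merge the two,
# so every head returns to its original position.
output = output.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(seq_per_rank, bsz, num_heads, head_dim)
```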
A more straightforward example:
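The single-process simulation below reproduces the effect without launching a distributed job; the world size and tensor sizes are made-up illustrative values.

```python
import torch

P, seq_per_rank, bsz, heads, dim = 2, 2, 1, 4, 1  # illustrative sizes

# Ground truth: the local sequence block with all heads; each value encodes
# (position, head), so any misalignment is easy to spot.
truth = torch.arange(seq_per_rank * bsz * heads * dim).reshape(
    seq_per_rank, bsz, heads, dim)

# Emulate the buffer all_to_all_single hands back: P chunks concatenated
# along dim 0, where chunk p carries head block p of the local sequence.
recv = (truth.reshape(seq_per_rank, bsz, P, heads // P, dim)
             .permute(2, 0, 1, 3, 4).contiguous()
             .reshape(P * seq_per_rank, bsz, heads // P, dim))

# Buggy: flatten the receive buffer straight into the target shape.
buggy = recv.reshape(seq_per_rank, bsz, heads, dim)

# Fixed: expose the rank dimension, move it next to the head dimension,
# then merge the two back into the full head dimension.
fixed = (recv.reshape(P, seq_per_rank, bsz, heads // P, dim)
             .permute(1, 2, 0, 3, 4).contiguous()
             .reshape(seq_per_rank, bsz, heads, dim))

print(torch.equal(buggy, truth))  # False -- heads are scrambled
print(torch.equal(fixed, truth))  # True  -- data arranged correctly
```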
For the first `all_to_all`, things work fine; this issue only exists in the second `all_to_all`.