microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
Apache License 2.0

Fixing the reshape bug in sequence parallel alltoall, which corrupted all QKV data #5664

Closed YJHMITWEB closed 1 week ago

YJHMITWEB commented 2 weeks ago

In the current implementation of DeepSpeed sequence parallelism, two all_to_all calls are used in the distributed attention to scatter and gather the sequence. However, the reshape operation in the second all_to_all is wrong. The model never converges because this reshape corrupts the data.
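For reference, the intended effect of the two exchanges can be sketched on a single process. This is only an illustration under assumptions, not DeepSpeed's code: NumPy stands in for torch, the helper names `all_to_all` and `seq_all_to_all` are hypothetical, the tensors are assumed to be laid out as [batch, seq, heads, head_dim], and the exchange is written with split and concatenate (one chunk per rank) rather than a single fused buffer. The first exchange turns sequence shards into head shards, and the second one should undo it exactly.

```python
import numpy as np

def all_to_all(chunks_per_rank):
    """Single-process stand-in for an all-to-all: rank dst receives chunk dst from every rank src."""
    p = len(chunks_per_rank)
    return [[chunks_per_rank[src][dst] for src in range(p)] for dst in range(p)]

def seq_all_to_all(shards, scatter_axis, gather_axis):
    """Split each rank's shard along scatter_axis, exchange, then concatenate along gather_axis."""
    p = len(shards)
    sends = [np.split(x, p, axis=scatter_axis) for x in shards]
    return [np.concatenate(recv, axis=gather_axis) for recv in all_to_all(sends)]

# Hypothetical small sizes: 4 ranks, batch 2, sequence 8, 4 heads, head dim 3.
p, b, s, H, d = 4, 2, 8, 4, 3
full = np.arange(b * s * H * d).reshape(b, s, H, d)
seq_shards = np.split(full, p, axis=1)  # each rank holds [b, s/p, H, d]

# First exchange: scatter heads, gather sequence -> each rank holds [b, s, H/p, d].
head_shards = seq_all_to_all(seq_shards, scatter_axis=2, gather_axis=1)
# Second exchange: scatter sequence, gather heads -> back to [b, s/p, H, d].
round_trip = seq_all_to_all(head_shards, scatter_axis=1, gather_axis=2)

print(head_shards[0].shape)  # (2, 8, 1, 3)
print(all(np.array_equal(x, y) for x, y in zip(round_trip, seq_shards)))  # True
```

With a correct implementation, skipping attention and feeding the query through both exchanges returns the original query on every rank, which is the property the check in this PR relies on.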

To check the problem with the current implementation, we can make the following change to this forward method:

def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
      """ forward

      Arguments:
          query (Tensor): query input to the layer
          key (Tensor): key input to the layer
          value (Tensor): value input to the layer
          args: other args

      Returns:
          * output (Tensor): context output
      """
      # TODO Merge three alltoall calls into one
      # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
      #in shape : e.g.,  [s/p:h:]
      query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
      key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
      value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

      #out shape : e.g., [s:h/p:]
      context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

      output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

      #out e.g., [s/p::h]
      return output

Remove the attention computation and keep only the all_to_all calls. The query then passes through both all_to_all calls untouched, so the output should be identical to the query:

query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

#out shape : e.g., [s:h/p:]
#context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) # do not perform attn,
context_layer = query_layer                                                            # just use the input query

output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

if torch.distributed.get_rank() == 3:
    # print both tensors on one rank so they can be compared
    print(query)
    print(output)

_In the current implementation, all_to_all scrambles the data: the printed values of query do not line up with those of output._

The problem is caused by this incorrect reshape:

def single_all_to_all(input, scatter_idx, gather_idx, group):
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
    if scatter_idx < 2:
        input_t = input.reshape(
            [seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # if scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[: gather_idx] + \
        [inp_shape[gather_idx] * seq_world_size,] + \
        inp_shape[gather_idx + 1:]).contiguous()

When performing the second all_to_all, the output we gathered from other ranks is of shape:

dist.all_to_all_single(output, input_t, group=group)
# output: [seq_world_size, batch, local_seq_length, num_local_heads, head_dim]

if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()
# output: [batch, seq_world_size, local_seq_length, num_local_heads, head_dim]

At this step we want to gather all the heads for the local sequence, so seq_world_size has to sit directly in front of num_local_heads before the final reshape. The line above therefore needs to be:

if scatter_idx < 2:
        output = output.transpose(0, 2)
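# output: [local_seq_length, batch, seq_world_size, num_local_heads, head_dim]
# (transpose(0, 2) swaps dims 0 and 2; for batch == 1 the element order is the same as
#  [batch, local_seq_length, seq_world_size, num_local_heads, head_dim])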

Only after this can we apply the final reshape:

return output.reshape(
    inp_shape[: gather_idx] + \
    [inp_shape[gather_idx] * seq_world_size,] + \
    inp_shape[gather_idx + 1:]).contiguous()

which then arranges the data correctly.

A more concrete example:

# second all_to_all
# batch: 1
# sequence parallel size: 4
# local sequence length: 8192
# total number of heads: 16
# head dim: 128

dist.all_to_all_single(output, input_t, group=group)
# output: [4, 1, 8192, 4, 128]

if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()
# output: [1, 4, 8192, 4, 128]
# At this step, you cannot directly reshape it into [1, 8192, 16, 128] as it corrupts the data.
# You need to permute output into [1, 8192, 4, 4, 128], then reshape it into [1, 8192, 16, 128].
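The effect described in the example above can be reproduced on a single process with a small NumPy sketch. It is an illustration under assumptions rather than the DeepSpeed code: NumPy stands in for torch, the sizes are shrunk, and slot r of the received buffer is assumed to hold the head group that rank r owned, for the local sequence, as in the shapes quoted above.

```python
import numpy as np

# Hypothetical small sizes: sequence parallel size, batch, local seq length, local heads, head dim.
p, b, s, h, d = 4, 1, 3, 2, 2

# What this rank should end up with: every head for its local sequence chunk.
expected = np.arange(b * s * p * h * d).reshape(b, s, p * h, d)

# Simulated result of the second all_to_all: slot r holds head group r.
received = np.stack([expected[:, :, r * h:(r + 1) * h, :] for r in range(p)])  # [p, b, s, h, d]

# Old behaviour: swap dims 0 and 1, then reshape. seq_world_size is not next to the heads.
buggy = received.transpose(1, 0, 2, 3, 4).reshape(b, s, p * h, d)

# Permute to [b, s, p, h, d] first, so that merging (p, h) yields the global head index.
fixed = received.transpose(1, 2, 0, 3, 4).reshape(b, s, p * h, d)

# The transpose(0, 2) proposed above gives [s, b, p, h, d]; with batch == 1 the element
# order matches, so the same reshape also recovers the data.
swapped = np.swapaxes(received, 0, 2).reshape(b, s, p * h, d)

print(np.array_equal(buggy, expected))    # False
print(np.array_equal(fixed, expected))    # True
print(np.array_equal(swapped, expected))  # True
```

The general permutation (1, 2, 0, 3, 4) does not depend on the batch size, whereas the equivalence of the swap-based variant shown here is only checked for batch == 1.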

For the first all_to_all, things work fine. This issue only exists in the second all_to_all.

tohtana commented 2 weeks ago

Thank you @YJHMITWEB! The bug was also pointed out by @chengming-zhang. It seems to have been introduced when we switched from multiple all_to_all calls to all_to_all_single. @chengming-zhang Can you help us validate that this fix works?