microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

Fixing the reshape bug in sequence parallel alltoall, which corrupted all QKV data #5664

Closed YJHMITWEB closed 1 week ago

YJHMITWEB commented 2 weeks ago

Currently, in the implementation of DeepSpeed sequence parallelism, two all_to_all operations are used in the distributed attention to scatter and gather the sequence. However, the reshape operation in the second all_to_all is wrong: it corrupts the data, so the model will never converge.
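
For context, here is a minimal shape sketch of what the round trip is supposed to do. The [batch, seq, heads, head_dim] labeling and the concrete sizes are illustrative assumptions taken from the example later in this issue, not necessarily DeepSpeed's canonical layout:

b, s, h, d, p = 1, 32768, 16, 128, 4        # batch, full sequence, total heads, head dim, sequence-parallel size

shape_before_first_a2a = (b, s // p, h, d)  # local sequence shard, all heads
shape_after_first_a2a  = (b, s, h // p, d)  # full sequence, local heads (local attention runs on this)
shape_after_second_a2a = (b, s // p, h, d)  # the second all_to_all should restore the input layout

assert shape_after_second_a2a == shape_before_first_a2a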

To easily check the problem with the current implementation, we can start from this forward method:

def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
      """ forward

      Arguments:
          query (Tensor): query input to the layer
          key (Tensor): key input to the layer
          value (Tensor): value input to the layer
          args: other args

      Returns:
          * output (Tensor): context output
      """
      # TODO Merge three alltoall calls into one
      # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
      #in shape : e.g.,  [s/p:h:]
      query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
      key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
      value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

      #out shape : e.g., [s:h/p:]
      context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

      output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

      #out e.g., [s/p::h]
      return output

Remove the attention computation, leaving only the all_to_all calls, and simply compare the query before and after the round trip; the two should be identical.

query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

#out shape : e.g., [s:h/p:]
#context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) # do not perform attn,
context_layer = query_layer                                                            # just use the input query

output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
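# round trip: after the forward and inverse all_to_all, `output` should match `query` exactly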

if torch.distributed.get_rank() == 3:
     print(query[0][15730][5])
     print(output[0][15730][5])

With the current implementation, the all_to_all completely scrambles the data: the values printed from query do not match those printed from output.

The problem is caused by this incorrect reshape:

def single_all_to_all(input, scatter_idx, gather_idx, group):
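    # scatter_idx: dimension that gets split across the sequence-parallel group
    # gather_idx:  dimension whose per-rank shards are concatenated after the exchange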
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
    if scatter_idx < 2:
        input_t = input.reshape(
            [seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # if scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[: gather_idx] + \
        [inp_shape[gather_idx] * seq_world_size,] + \
        inp_shape[gather_idx + 1:]).contiguous()

When performing the second all_to_all, the output gathered from the other ranks has the shape:

dist.all_to_all_single(output, input_t, group=group)
# output: [seq_world_size, batch, local_seq_length, num_local_heads, head_dim]

if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()
# output: [batch, seq_world_size, local_seq_length, num_local_heads, head_dim]

At this step, we actually want to gather all the heads for each position of the local sequence; therefore, the line above needs to be:

if scatter_idx < 2:
        output = output.transpose(0, 2)
# output: [batch, local_seq_length, seq_world_size, num_local_heads, head_dim]

Only after this step can we apply the final reshape:

return output.reshape(
    inp_shape[: gather_idx] + \
    [inp_shape[gather_idx] * seq_world_size,] + \
    inp_shape[gather_idx + 1:]).contiguous()

and have the data arranged correctly.

A more straightforward example:

# second all_to_all
# batch: 1
# sequence parallel size: 4
# local sequence length: 8192
# total number of heads: 16
# head dim: 128

dist.all_to_all_single(output, input_t, group=group)
# output: [4, 1, 8192, 4, 128]

if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()
# output: [1, 4, 8192, 4, 128]
# At this step, you cannot directly reshape it into [1, 8192, 16, 128] as it corrupts the data.
# You need to permute output into [1, 8192, 4, 4, 128], then reshape it into [1, 8192, 16, 128].
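
To make this concrete, here is a small, self-contained sketch (not DeepSpeed code) that emulates the second all_to_all on a single process with tiny sizes, using the same [batch, seq, heads, head_dim] labeling as the example above; the dimension names and sizes are illustrative assumptions:

import torch

b, s, h, d = 1, 8, 4, 2          # tiny stand-ins for batch, full sequence, total heads, head dim
P = 2                            # sequence-parallel world size
s_local, h_local = s // P, h // P

# "Global" context tensor with recognizable values.
ctx = torch.arange(b * s * h * d).reshape(b, s, h, d)

rank = 1                         # look at one rank

# What the second all_to_all delivers to `rank`: from every source rank r it receives
# r's head block restricted to `rank`'s sequence shard, stacked along a new dim 0.
recv = torch.stack([
    ctx[:, rank * s_local:(rank + 1) * s_local, r * h_local:(r + 1) * h_local, :]
    for r in range(P)
])                               # shape: [P, b, s_local, h_local, d]

# What the rank should end up with: its sequence shard with all heads gathered.
expected = ctx[:, rank * s_local:(rank + 1) * s_local, :, :]            # [b, s_local, h, d]

# Buggy path: transpose(0, 1) and then reshape straight to [b, s_local, h, d].
buggy = recv.transpose(0, 1).contiguous().reshape(b, s_local, h, d)

# Permute-then-reshape path described in this issue: move the rank axis next to the
# local-head axis first, then merge the two into the full head dimension.
fixed = recv.permute(1, 2, 0, 3, 4).contiguous().reshape(b, s_local, h, d)

print(torch.equal(buggy, expected))   # False: sequence and head data are interleaved
print(torch.equal(fixed, expected))   # True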

The first all_to_all works fine; the issue exists only in the second all_to_all.

tohtana commented 2 weeks ago

Thank you @YJHMITWEB! The bug was also pointed out by @chengming-zhang. It seems to have been introduced when we switched from multiple all_to_all calls to all_to_all_single. @chengming-zhang, can you help us validate that this fix works?