microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Sequence Parallel(Ulysses) Training Gradient Scaling Issue #5248

Open tkdcjf159 opened 7 months ago

tkdcjf159 commented 7 months ago

When training a language model (LM) with DeepSpeed's Sequence Parallel (Ulysses), it's typical to get a cross-entropy loss on each rank. To compute the gradients correctly, as I understand it, an in-place division via tensor_to_allreduce.div_ followed by an all-reduce operation is necessary:


process_group = self.dp_process_group if process_group is None else process_group
..
tensor_to_allreduce.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))

Without performing tensor_to_allreduce.div_, the gradient would be scaled by the sequence parallel size, resulting in much higher gradients than expected. In an effort to address this, I've reflected this change in this commit, but looking at the current main code, it seems that divide is always set to False as of this commit, so it appears tensor_to_allreduce.div_ might not be applied correctly.

Alternatively, would it be acceptable to just divide the loss by SEQUENCE_PARALLEL_WORLD_SIZE and then perform a backward operation?
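
To make the scaling I'm describing concrete, here is a toy single-process sketch of the arithmetic (not DeepSpeed code; dp_world_size, sp_size, and the unit gradients are made-up values) showing why the divisor ends up being dp_world_size / sp_size:

dp_world_size = 8                       # ranks in the gradient all-reduce group (made up)
sp_size = 2                             # Ulysses sequence-parallel degree (made up)
n_replicas = dp_world_size // sp_size   # truly independent data-parallel replicas

partial_grads = [1.0] * dp_world_size   # each rank's partial gradient after backward()

# What we want: SUM the sp_size partials inside a sequence-parallel group (they are
# pieces of one sample's gradient), then AVERAGE across the independent replicas.
per_replica = sp_size * 1.0             # sum inside one SP group
desired = per_replica                   # identical replicas, so the average is unchanged

# What the reducer computes: all-reduce(SUM) over all ranks, pre-divided by
# dp_world_size / sp_size rather than plain dp_world_size.
allreduce_sum = sum(partial_grads)                          # 8.0
scaled = allreduce_sum / (dp_world_size / float(sp_size))   # 8.0 / 4.0 = 2.0

assert scaled == desired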

RezaYazdaniAminabadi commented 7 months ago

Hi @tkdcjf159,

The gradients get scaled correctly because the division happens before reaching this function:

  1. non-MoE parameters: https://github.com/microsoft/DeepSpeed/blob/535a908f1b60f819df4ccf1071f7c917c39dabbe/deepspeed/runtime/zero/stage_1_and_2.py#L1107
  2. MoE parameters: https://github.com/microsoft/DeepSpeed/blob/535a908f1b60f819df4ccf1071f7c917c39dabbe/deepspeed/runtime/zero/stage_1_and_2.py#L1067
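
In simplified form, the reduction path behind those lines does the equivalent of the following (a paraphrase, not the actual DeepSpeed source; the function name here is made up):

import torch.distributed as dist

def average_gradient_buffer(buffer, dp_process_group, sequence_parallel_size):
    # Pre-scale by dp_world_size / sequence_parallel_size before the all-reduce ...
    world_size = dist.get_world_size(group=dp_process_group)
    buffer.div_(world_size / float(sequence_parallel_size))
    # ... then sum across the data-parallel group, yielding the averaged gradient.
    dist.all_reduce(buffer, group=dp_process_group)
    return buffer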

So, there is no need to do the division in the place you mentioned. I think we should remove this flag at some point to resolve the confusion. Thanks, Reza

tkdcjf159 commented 7 months ago

@RezaYazdaniAminabadi Thank you for letting me know. I'm not sure whether this is related to this issue, but when reduce_bucket_size and allgather_bucket_size are small (e.g., hidden_size * hidden_size) in the zero_optimization configuration, there have been cases under ZeRO-2 where the loss becomes NaN once sequence parallelism is applied. I'm curious whether you happen to know the reason for this as well.
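
For reference, a config of the shape I mean (all values illustrative; hidden_size stands in for the model's hidden dimension, and the batch-size and bf16 entries are just placeholders):

hidden_size = 4096  # assumed example value
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 2,
        "reduce_bucket_size": hidden_size * hidden_size,
        "allgather_bucket_size": hidden_size * hidden_size,
    },
    "bf16": {"enabled": True},
}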

Kwen-Chen commented 5 months ago

When training a language model (LM) with DeepSpeed's Sequence Parallel (Ulysses), it's typical to get a cross-entropy loss on each rank. To compute the gradients correctly, as I understand it, an in-place division via tensor_to_allreduce.div_ followed by an all-reduce operation is necessary:

process_group = self.dp_process_group if process_group is None else process_group
..
tensor_to_allreduce.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))

Without performing tensor_to_allreduce.div_, the gradient would be scaled by the sequence parallel size, resulting in much higher gradients than expected. In an effort to address this, I've reflected this change in this commit, but looking at the current main code, it seems that divide is always set to False as of this commit, so it appears tensor_to_allreduce.div_ might not be applied correctly.

Alternatively, would it be acceptable to just divide the loss by SEQUENCE_PARALLEL_WORLD_SIZE and then perform a backward operation?

  • example code
for step, batch in enumerate(dataloader):
    if ENABLE_DS_SEQUENCE_PARALLEL:
        # Slice out this rank's sub-sequence for sequence parallelism.
        seq_length = batch['input_ids'].size(1)
        assert seq_length % SEQUENCE_PARALLEL_WORLD_SIZE == 0
        sub_seq_length = seq_length // SEQUENCE_PARALLEL_WORLD_SIZE
        sub_seq_start = SEQUENCE_PARALLEL_RANK * sub_seq_length
        sub_seq_end = (SEQUENCE_PARALLEL_RANK + 1) * sub_seq_length
        # Move to device: shape [B, T/SP]
        input_ids = batch['input_ids'][:, sub_seq_start:sub_seq_end].to(device)
        attention_mask = batch['attention_mask'][:, sub_seq_start:sub_seq_end].to(device)
        labels = batch['labels'][:, sub_seq_start:sub_seq_end].to(device)
        position_ids = torch.arange(seq_length).unsqueeze(0)
        position_ids = position_ids[:, sub_seq_start:sub_seq_end].to(device)
    else:
        # Move to device: shape [B, T]
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        position_ids = None

    # Forward pass
    loss = model_engine(input_ids=input_ids, attention_mask=attention_mask, labels=labels, position_ids=position_ids)

    # Backward pass: would it be acceptable to calculate the loss in this alternative way?
    model_engine.backward(loss / SEQUENCE_PARALLEL_WORLD_SIZE)

    # Weight update
    model_engine.step()

Your code has been very helpful to me. Could I ask where the sequence_parallel_group is defined? Is the sequence_parallel_group information passed into the initialization function when the model_engine is created? If you could provide a small demo of initializing the model_engine, I would be very grateful.
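
For context, this is roughly what I am imagining (a rough sketch only, not an official recipe: the SPGroups class and the assumption that DeepSpeed picks the sequence-parallel group up from the mpu argument are my guesses, a real Megatron-style mpu also exposes the data/model-parallel accessors, and the DeepSpeed-Ulysses tutorial should be the reference for the exact contract):

import torch
import torch.distributed as dist
import deepspeed

SP_SIZE = 2  # assumed sequence-parallel degree

class SPGroups:
    """Hypothetical, minimal mpu-like object holding the sequence-parallel group."""

    def __init__(self, sp_size):
        world, rank = dist.get_world_size(), dist.get_rank()
        assert world % sp_size == 0
        # Consecutive ranks form one sequence-parallel group; every rank must call
        # new_group() for every group, even the ones it does not belong to.
        for start in range(0, world, sp_size):
            ranks = list(range(start, start + sp_size))
            group = dist.new_group(ranks)
            if rank in ranks:
                self._sp_group, self._sp_rank = group, ranks.index(rank)
        self._sp_size = sp_size

    def get_sequence_parallel_group(self):
        return self._sp_group

    def get_sequence_parallel_world_size(self):
        return self._sp_size

    def get_sequence_parallel_rank(self):
        return self._sp_rank

deepspeed.init_distributed()          # initializes torch.distributed if needed
mpu = SPGroups(SP_SIZE)
model = torch.nn.Linear(8, 8)         # stand-in for the real LM
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 2},
}
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
    mpu=mpu,  # my guess: the engine reads the sequence-parallel group from here
)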

jinhuaca commented 2 months ago

Why do we need to do the division

tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))?

Is it to scale the loss function up according to the sequence parallel group size?

Why divide rather than multiply? For example, why not do

tensor.div_(dist.get_world_size(group=self.dp_process_group) * float(self.sequence_parallel_size))
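
To make the comparison concrete, here is a toy arithmetic check of the two divisors (values made up, not DeepSpeed code); the multiplied form would shrink the result by a further factor of sp**2:

dp, sp = 8, 2
partials = [1.0] * dp                # one partial gradient per rank after backward()
target = sp * 1.0                    # sum within an SP group, averaged over identical replicas

print(sum(partials) / (dp / sp))     # divide by dp/sp -> 2.0, matches target
print(sum(partials) / (dp * sp))     # divide by dp*sp -> 0.5, too small by a factor of sp**2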