universome opened this issue 2 months ago
@universome, thanks for the extensive triaging and for sharing the details.
Since sharing a repro is impossible, could you try to simplify the ds_config in order to achieve parity between ZeRO-1 and ZeRO-2? In particular, could you do the following:

1. Remove `data_types: {grad_accum_dtype: fp32}` so that the grad accum dtype is bf16. The logic for switching dtypes is complex enough to be suspect. Do this for both ZeRO-1 and ZeRO-2; in other words, we obtain a new ZeRO-1 baseline curve.
2. `overlap_comm: False` - disable overlapping of backward compute and gradient reduction.
3. `use_multi_rank_bucket_allreduce: False` - disable the complicated but performant reduction logic.
4. Set `reduce_bucket_size` to the model size (or 2X the model size) to avoid overlapping reduction with backward. This further mimics ZeRO-1 behavior, especially if there is no gradient accumulation, by performing one reduction operation at the end of backward.

A sketch of the combined settings follows below. Please share whether the above steps are feasible. Thanks!
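For concreteness, here is a minimal sketch of what the simplified ds_config could look like after these steps. It is only an illustration: the micro-batch size is a placeholder and the concrete `reduce_bucket_size` value simply mirrors the ~600M parameter count mentioned in the bug report.

```python
# Hypothetical simplified ds_config for the ablation above (values are placeholders).
# The grad accum dtype is left at its bf16 default by omitting the data_types section.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,           # placeholder
    "bf16": {"enabled": True},
    "communication_data_type": "bf16",
    "zero_optimization": {
        "stage": 2,                                # rerun with stage 1 for the baseline curve
        "overlap_comm": False,                     # step 2: no backward/reduction overlap
        "use_multi_rank_bucket_allreduce": False,  # step 3: simple reduction path
        "reduce_bucket_size": 600_000_000,         # step 4: ~model size (600M parameters)
    },
    # step 1: no "data_types": {"grad_accum_dtype": "fp32"} override -> bf16 grad accum
}
```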
Hi @tjruwase, thank you so much for your help and the pointers! We had already explored some of these parameters before (i.e., `overlap_comm`, fp32 grad accumulation, a huge `reduce_bucket_size`). I've now launched the following ablations: `overlap_comm: false`, `use_multi_rank_bucket_allreduce: false` (I was not aware of this flag; it seems to be missing from the docs), a 5x larger `reduce_bucket_size`, and the print statement in IPG reduction.

What I observe for now is that the condition `self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size` here is never satisfied, since the `reduce_bucket_size` is so large. I guess that's expected behavior and the if branch should never be triggered there?
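To make the effect of that condition concrete, here is a toy illustration (not DeepSpeed code; only the quoted condition is taken from the source) of the bucketing rule it implements:

```python
# Toy model of the IPG bucketing rule under discussion: gradients accumulate in a
# bucket, and the bucket is reduced whenever adding the next parameter's gradient
# would exceed reduce_bucket_size.
def count_bucket_flushes(param_sizes, reduce_bucket_size):
    elements_in_ipg_bucket = 0
    flushes = 0
    for numel in param_sizes:
        if elements_in_ipg_bucket + numel > reduce_bucket_size:  # the quoted condition
            flushes += 1                    # a reduction that overlaps with backward
            elements_in_ipg_bucket = 0
        elements_in_ipg_bucket += numel
    return flushes

params = [10_000_000] * 60                            # hypothetical ~600M-parameter split
print(count_bucket_flushes(params, 5 * sum(params)))  # 0: everything reduced at the end
print(count_bucket_flushes(params, 10_000_000))       # many mid-backward reductions
```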
Also, in your email you said:

> Your observed impact of `reduce_bucket_size` makes sense because it triggers reduction operations which overlap with the backward computation when the generated gradient count hits a threshold.

Do I get it right that this is due to some race conditions? I.e., some gradients are being reduced before they are fully ready?
UPD: I've now relaunched ZeRO-2 with the above changes, but with fp32 grads, and I'm also trying v0.15 of DeepSpeed.
> What I observe for now is that the condition `self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size` here is never satisfied, since the `reduce_bucket_size` is so large. I guess that's expected behavior and the if branch should never be triggered there?
Yes, this is expected behavior, so that reduction is delayed until the end of backward, after all local gradients are created. This further mimics ZeRO-1 behavior.
> Your observed impact of `reduce_bucket_size` makes sense because it triggers reduction operations which overlap with the backward computation when the generated gradient count hits a threshold.
>
> Do I get it right that this is due to some race conditions? I.e., some gradients are being reduced before they are fully ready?
Yes, a race condition is my top suspect. This is because the main gradient operations in ZeRO-2 are:

1. gradient creation during backward,
2. gradient reduction, overlapped with backward,
3. partitioning of the reduced gradients.

I am concerned that 1, 2, and 3 are not correctly synchronized, or that our synchronization assumptions are not adequate for your workload. It is great that ZeRO-1 seems to work, as that allows us to incrementally debug the ZeRO-2 extensions.
I see, thank you! Since I'm currently not using overlapping (i.e., item 2 is out of the equation) and the `reduce_bucket_size` is 5x the model size, could you please explain how backward and the partitioning of reduced gradients (in just a single bucket?) can face any race hazards? I might then be able to modify some code blocks in the optimizer to ablate some of the logic.
Regarding the past two experiments (Stage 2 + `overlap_comm: false` + `data_types: {grad_accum_dtype: fp32}` + `use_multi_rank_bucket_allreduce: false` + `reduce_bucket_size: <5x model size>`, for versions 0.14.12 and 0.15): they also didn't converge well (purple is stage 1, green is stage 2 on v0.15, brown is stage 2 on v0.14.12):
> could you please explain how backward and the partitioning of reduced gradients (in just a single bucket?) can face any race hazards?

Partitioning happens here, with the assumption that reduction is completed: https://github.com/microsoft/DeepSpeed/blob/9bc4cd01b7eb6039bb2c02b63b4b37720733dee3/deepspeed/runtime/zero/stage_1_and_2.py#L1399-1400
> I might then be able to modify some code blocks in the optimizer to ablate some of the logic.

In terms of optimizer ablation, one idea would be to disable the optimizer by using a learning rate of 0. Assuming the curve mismatches remain, that would significantly reduce the code surface to investigate.
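As a concrete sketch of that ablation (the optimizer type here is a placeholder, not taken from the actual config), the relevant ds_config fragment could be:

```python
# Hypothetical ds_config fragment for the "disable the optimizer" ablation:
# with lr=0 the optimizer step leaves the weights unchanged.
optimizer_ablation_config = {
    "optimizer": {
        "type": "AdamW",            # placeholder: whichever optimizer the run already uses
        "params": {"lr": 0.0},
    }
}
```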
> UPD:
> - Zero1 with bf16 gradients performs similarly to Zero2

Sorry, it seems I missed this. Does this mean that for ZeRO-1, bf16 grad != fp32 grad?
> Does this mean that for ZeRO-1, bf16 grad != fp32 grad?

Yes, exactly: for ZeRO-1, bf16 grad != fp32 grad, and it performs poorly with bf16 grad (on par with ZeRO-2). Does that decrease the probability that the problem is due to race conditions in ZeRO-2?
> In terms of optimizer ablation, one idea would be to disable the optimizer by using a learning rate of 0. Assuming the curve mismatches remain, that would significantly reduce the code surface to investigate.

Should I disable weight decay as well? If so, the weights won't be updated at all, and wouldn't the loss curve then become "flat" (up to variation in the training data)?
> Yes, exactly: for ZeRO-1, bf16 grad != fp32 grad, and it performs poorly with bf16 grad (on par with ZeRO-2).

This is very strange. Can you confirm that the communication dtype (i.e., the reduction dtype) is bf16 in both cases?

> Does that decrease the probability that the problem is due to race conditions in ZeRO-2?

Not likely, since you also observed that ZeRO-1 + fp32 grad accum != ZeRO-2 + fp32 grad accum.
> Should I disable weight decay as well? If so, the weights won't be updated at all, and wouldn't the loss curve then become "flat" (up to variation in the training data)?

Yes, also disable weight decay so that the weights are not updated at all. And yes, this should give similar loss curves for all four variations of ZeRO-1/-2 with grad accum = bf16/fp32. Additionally, please print the global gradient norms by calling `engine.get_global_grad_norm()`:
https://github.com/microsoft/DeepSpeed/blob/4864991f53bd2e12446198bcc655f919eb9157f9/deepspeed/runtime/engine.py#L492

These runs should be short, only up to the point of divergence (e.g., ~250K steps in your plot). Hope that makes sense.
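A rough sketch of such an instrumented short run is below; apart from `engine.get_global_grad_norm()` (linked above), the loop, the forward call, and the step limit are placeholders for the real training code.

```python
# Sketch of the short diagnostic run: with lr=0 and weight_decay=0 in the ds_config
# the weights stay frozen, and the global gradient norm is logged every step.
# `engine` is the object returned by deepspeed.initialize().
def run_grad_norm_probe(engine, train_loader, max_steps=250_000):
    for step, batch in enumerate(train_loader):
        loss = engine(batch)              # placeholder forward returning the loss
        engine.backward(loss)
        engine.step()                     # effectively a no-op update (lr=0, wd=0)
        print(f"step={step} loss={loss.item():.4f} "
              f"global_grad_norm={engine.get_global_grad_norm()}")
        if step + 1 >= max_steps:         # only run up to roughly the divergence point
            break
```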
> This is very strange. Can you confirm that the communication dtype (i.e., the reduction dtype) is bf16 in both cases?

Yes, in the config we have `communication_data_type: bf16`, and in both cases the logs print `[2024-08-17 06:04:18,674] [INFO] [config.py:1000:print] communication_data_type ...... torch.bfloat16`. (There is also `[2024-08-17 06:04:18,677] [INFO] [config.py:1000:print] seq_parallel_communication_data_type torch.float32` for both runs, but we are not using sequence-parallel training.)

To enable the bf16 grad accum dtype, I simply commented out `data_types: {grad_accum_dtype: fp32}` in the config. I also confirmed that it was set as expected by printing `<ds_engine>.get_data_types()`, which returned `(torch.bfloat16, torch.bfloat16)` instead of `(torch.bfloat16, torch.float32)`.
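For completeness, the check amounts to something like the sketch below (`engine` stands for the DeepSpeed engine instance; the assertion is illustrative):

```python
import torch

# Sanity check for the grad accumulation dtype once the fp32 override is commented out.
def check_grad_accum_dtype(engine):
    model_dtype, grad_accum_dtype = engine.get_data_types()
    print(f"model dtype: {model_dtype}, grad accum dtype: {grad_accum_dtype}")
    assert grad_accum_dtype == torch.bfloat16  # was torch.float32 with the fp32 override
```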
> this should give similar loss curves for all four variations of ZeRO-1/-2 with grad accum = bf16/fp32.

I see, I will launch them today, thank you.
Hi @tjruwase, we ran those 4 experiments and found that the gradient norm for Zero1-grad-acc-fp32 is slightly different from Zero1-grad-acc-bf16, Zero2-grad-acc-fp32, and Zero2-grad-acc-bf16 (which are all similar to each other). Attached are the plots: global gradient norm, global gradient norm (smoothed with wandb), and the loss curve:
Hello @universome, I am Chengming Zhang, a collaborator of Olatunji (@tjruwase), and I am helping to resolve this issue. To reproduce it, I created an example diffusion model using diffusers (https://github.com/zhangsmallshark/ds-debug). (The code there is an old version; I will update it when my remote server is back in 2 days, so that you can run it as well.) I am using the same configurations as you. As shown in the result figure, my observation is that, for this model, neither the ZeRO stage nor `reduce_bucket_size` affects the loss.

In the legend, the digit after "zero" is the ZeRO stage, "10k" is the dataset size, and the number after "bucket" is the `reduce_bucket_size`.

One possible reason for the difference: your team may be using a different model structure than mine. I am using UNet2DModel. If possible, could you point out which model in diffusers is similar to yours, so that I can modify my example and try to reproduce this issue?
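For reference, a minimal way to instantiate such a model with diffusers is sketched below; the channel and layer sizes are placeholders, not the exact configuration used in the linked ds-debug repo.

```python
from diffusers import UNet2DModel

# Minimal UNet2DModel instantiation (placeholder sizes, not the ds-debug configuration).
model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count
```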
@universome Hello, this is just a follow-up. I have already updated the example code, and you can run it directly via `bash run_train.sh`. Please let me know of any updates when you have time.
Hi @zhangsmallshark, I really apologize for replying so late. The thing is that we've been exploring FSDP2 + tensor parallelism for our pipeline, and it seems to be working reasonably well for us at the moment (it shows some strangely surging gradient norms, but we do not observe any loss divergence, at least so far). I will try to run the above ablations as soon as I have availability.
Describe the bug
I launch DeepSpeed training for a 600M-parameter diffusion model and only vary `reduce_bucket_size`. I tried the following values:

- `reduce_bucket_size: 500_000_000` - converges poorly.
- `reduce_bucket_size: 1_000_000_000` - converges slightly better in the beginning, but is then still worse than ZeRO Stage 1.
- `reduce_bucket_size: 10_000_000` - almost does not converge at all; the losses are several times higher.
- `reduce_bucket_size: 1_000_000` - I start getting NaN loss values almost immediately during training.

I use DeepSpeed 0.14.2 (upgrading to 0.14.5 didn't help). The rest of my config looks like this:
To Reproduce
As of now, I cannot provide a simple reproducible example since the code is deep in our internal codebase. I want to ask where I can look in order to isolate/locate the issue.
Expected behavior
`reduce_bucket_size` shouldn't influence training at all.

Screenshots
Here is a screenshot with my training losses:

- `reduce_bucket_size: 500_000_000` (default)
- `reduce_bucket_size: 1_000_000_000`
- `reduce_bucket_size: 10_000_000`
- `reduce_bucket_size: 1_000_000`
System info (please complete the following information):
Launcher context
I launch via `torch.distributed.run`.

Docker context
It's an internal one; I cannot share it.
Additional context
I use activation checkpointing with `use_reentrant=True` (`use_reentrant=False` makes the model converge worse) and mixed-precision training. For debugging purposes, I tried running backward with `allreduce_gradients=False` and then calling `optimizer.reduce_gradients(pipeline_parallel=False)` manually (to follow ZeRO Stage 1), but it didn't help.
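Roughly, that debugging variant looked like the sketch below; `engine` is the DeepSpeed engine and the forward call is a placeholder for the real training code.

```python
# Sketch of the manual-reduction debugging variant: skip the automatic allreduce during
# backward and reduce gradients explicitly afterwards, mimicking ZeRO Stage 1's single
# end-of-backward reduction.
def debug_step_with_manual_reduction(engine, batch):
    loss = engine(batch)                                         # placeholder forward
    engine.backward(loss, allreduce_gradients=False)
    engine.optimizer.reduce_gradients(pipeline_parallel=False)   # explicit reduction
    engine.step()
    return loss
```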