mlfoundations / scaling

Language models scale reliably with over-training and on downstream tasks
MIT License

FSDP Mixed Precision Setting #3

Open maximilianmbeck opened 5 months ago

maximilianmbeck commented 5 months ago

Thank you for this nice paper, your new insights, and the detailed training setup description in Section 3.1.

You mention that you are using PyTorch FSDP for training. I have some additional questions regarding this. What is your FSDP MixedPrecision setting, and how many nodes (and GPUs per node) do you use for training your neural networks? Also, which GPUs do you use?

Thanks a lot.

Best, Max

sagadre commented 5 months ago

Hi @maximilianmbeck, glad you liked Section 3.1!

For FSDP, here are some recommendations for flags in open_lm. See main.py for the implementation details:

For training:

maximilianmbeck commented 5 months ago

Hi @sagadre, thanks for your recommendations and your quick response.

I suspect that you use fsdp-amp for the smaller models, since there is no inter-node communication (you use only 1 node), and fsdp-pure-bf16 (with bfloat16 gradient reductions) to reduce the communication cost for multi-node training. Do you agree?
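To make sure we mean the same thing: in plain PyTorch terms, I would express the two settings roughly like this, a sketch using torch.distributed.fsdp.MixedPrecision, not your actual open_lm configuration:

```python
# Sketch of the two settings as I understand them, via PyTorch's
# torch.distributed.fsdp.MixedPrecision (not open_lm's exact config).
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# "fsdp-amp"-style: bfloat16 compute, but gradients are reduced in float32.
amp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,   # fp32 gradient all-reduce across ranks
    buffer_dtype=torch.bfloat16,
)

# "fsdp-pure-bf16"-style: everything, including gradient reduction, in bfloat16.
pure_bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,  # bf16 gradient all-reduce (half the bytes)
    buffer_dtype=torch.bfloat16,
)

# model = FSDP(model, mixed_precision=pure_bf16_policy, ...)
```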

In addition, you report that you are using "an additive z-loss (coefficient 1e-4), which mitigates output logit norm growth instabilities".

I have some follow-up thoughts and experiments on this. During training of our models we also observed a growth of the output logit norm, which led to Infs in our PPL metrics at some point later in training. Even though we observed that we could mitigate this by adding a regularizing loss that pushes down the output logits, we tried to avoid using such a loss, similar to the z-loss suggested by the PaLM paper. Instead we investigated the PyTorch FSDP Mixed Precision settings, as we suspected bfloat16 to cause issues here.
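For reference, this is the kind of auxiliary term I mean; a minimal sketch of a PaLM-style z-loss (the function and argument names are only illustrative):

```python
# Minimal sketch of a PaLM-style z-loss added to the cross-entropy objective.
# `z_loss_coef` is an illustrative name; your paper reports a coefficient of 1e-4.
import torch
import torch.nn.functional as F

def loss_with_z_loss(logits, targets, z_loss_coef=1e-4):
    # logits: (batch, seq, vocab), targets: (batch, seq)
    ce = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
    # log(Z) per token; penalizing its square keeps the softmax normalizer
    # close to 1 and thereby limits output logit norm growth.
    log_z = torch.logsumexp(logits.float(), dim=-1)
    return ce + z_loss_coef * (log_z ** 2).mean()
```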

We trained two Transformer-like models of sizes 125M and 1.3B on next-token prediction on 4, 16, and 32 nodes (see below). We trained for approx. 10k steps with the hyperparameters specified below.

Note: drd in the run names below corresponds to the reduce_dtype setting of FSDP MixedPrecision.

Experiment 1:

A model with 125M parameters trained on 4 nodes and on 16 nodes, both with reduce_dtype=bfloat16. As sharding strategy we use NO_SHARD.

Brown: B24E768gbs256--s-NO_SHARD-nn-16-drd-bfloat16-sn-125M-utc-1-l-0.0003-wd-0.1-nb-24-ed-768-seed-42
Blue: B24E768gbs256--s-NO_SHARD-nn-4-drd-bfloat16-sn-125M-utc-1-l-0.0003-wd-0.1-nb-24-ed-768-seed-0

[Plot: fsdpprecision_125M_nn4_vs_nn16]

Experiment 2:

A model with 1.3B parameters trained on 32 nodes with DDP, with FSDP NO_SHARD and reduce_dtype=float32, and with FSDP NO_SHARD and reduce_dtype=bfloat16. We compare FSDP with sharding strategy NO_SHARD to DDP.

Grey: B48E2048gbs512--s-NO_SHARD-nn-32-drd-float32-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42
Red: B48E2048gbs512--s-DDP-nn-32-drd-bfloat16-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42
Green: B48E2048gbs512--s-NO_SHARD-nn-32-drd-bfloat16-sn-1_3B-utc-1-l-0.0002-wd-0.1-nb-48-ed-2048-seed-42

[Plots: fsdpprecision_1_3B_fsdpsettings, fsdpprecision_1_3B_fsdpsettings-zoom]

We observe that setting reduce_dtype=bfloat16 in training setups with more than 4 nodes causes the output logit norm to grow. When training with FSDP and reduce_dtype=float32, or when training with DDP (we think DDP also reduces gradients in float32), the output logit norm did not grow. In other experiments we even observed that the growth of the output logit norm scales roughly linearly with the number of nodes.

The tricky thing is that this behavior is not visible in the loss (see screenshots), so it is hard to trace this issue back to FSDP Mixed Precision.
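Since the loss does not show it, the only way we caught this was by logging the output logit norm directly as a training metric; a minimal sketch (the logger interface and metric name are illustrative, not tied to any particular codebase):

```python
# Minimal sketch of logging the output logit norm during training, so this
# kind of drift is visible even while the loss curve still looks fine.
import torch

@torch.no_grad()
def log_logit_norm(logits, logger, step):
    # Mean L2 norm of the output logits over all tokens in the batch.
    logit_norm = logits.float().norm(dim=-1).mean()
    logger.log({"output_logit_norm": logit_norm.item()}, step=step)
```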

We think this is a severe issue that needs more investigation, since the reduce dtype has a major impact on training speed, and one would actually use bfloat16 for higher training throughput, as you did for your 7B model on 16 nodes.

So my follow-up questions are:

sagadre commented 5 months ago

Hi @maximilianmbeck this is indeed very interesting and thanks for including the plots here for this detailed investigation!

achalddave commented 5 months ago

Yeah, when we scaled the 7B to multiple nodes, fsdp-pure-bf16 helped fairly significantly. I don't have results with our most recent codebase, but early on, on 2 nodes, going from reduce=float32 to reduce=bf16 increased speed from 3000 tokens/second/GPU to 3200 t/s/g on an A100. Our most recent runs are somewhere over 4k t/s/g with open_lm; not sure exactly what the impact of fsdp-pure is with the latest code. Hope that helps!

maximilianmbeck commented 5 months ago

Hi @sagadre and @achalddave,

Ah, too bad; I would have been really interested in the logit norm curve when using vs. not using the z-loss in your setup. But thanks anyway for your help.

Using bfloat16 as the reduce dtype really makes sense when optimizing for speed. In terms of bytes to transfer, the communication cost should be halved when switching from float32 to bfloat16.
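For concreteness, a back-of-the-envelope calculation for the 1.3B model (gradient bytes reduced per step, ignoring all-reduce algorithm overhead):

```python
# Back-of-the-envelope gradient volume for a 1.3B-parameter model.
n_params = 1.3e9
for dtype_name, bytes_per_elem in [("float32", 4), ("bfloat16", 2)]:
    gb = n_params * bytes_per_elem / 1e9
    print(f"{dtype_name}: ~{gb:.1f} GB of gradients reduced per step")
# float32: ~5.2 GB of gradients reduced per step
# bfloat16: ~2.6 GB of gradients reduced per step
```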

However, what also puzzles me is why OLMo reduce their gradients in float32; see their Section 3.1 (it's also very nice ;)). There must be a reason for trading off this speedup, and I would like to understand why. Did they make the same observation about logit norm growth? (Maybe I'll open an issue there, too.)

In any case, thanks a lot. I really like your paper. Cheers, Max

sagadre commented 5 months ago

Hi Max,

We are a bit tight on compute right now, but this is super interesting. When some compute opens up, I am happy to make the plots you are suggesting! Would be interesting to know if this pops up in different experimental setups.

As for OLMo, not completely sure. Their paper does not mention z-loss, but their codebase seems to support it. It is possible they did not try bf16 reduction + z-loss?

We were also reducing in fp32 for a long time, until @achalddave ran speed ablations. By the time we switched to bf16 reduction, we were already using the z-loss with its default coefficient of 1e-4.

Thanks again for the issue and the plots! Will leave the issue open for now!

Best, Samir

maximilianmbeck commented 5 months ago

Hi Samir,

Yes, indeed, that would be really interesting!

Not sure, but they very likely tried out the z-loss (https://github.com/allenai/OLMo/issues/361). So far no one has reacted to my issue in their repo (https://github.com/allenai/OLMo/issues/514).

Let me know if I can help with anything or if you need further details. Looking forward to your plots.

Best, Max