tdrussell / qlora-pipe

A pipeline parallel training script for LLMs.
MIT License

How to correctly set bf16 for training? The Deepspeed config's additional offload/bf16 settings seem to conflict with the qlora-pipe settings #16

Open iamhappytoo opened 2 months ago

iamhappytoo commented 2 months ago

Hello @tdrussell,

First of all, thank you very much for your great repo! It is absolutely great work to pull all these optimization solutions together. When using the repo, I tried enabling bf16 in the Deepspeed config to support full-parameter fine-tuning of a 70b model, but that seems to use even more memory (and OOM) than without it. With the default setting (not setting bf16: true in the Deepspeed config, and setting every available option in the .toml config to bfloat16), the evaluation phase still seems to run in float32: the 70b model consumes about 33 GB on each of the 8 GPUs during evaluation, and OOMs once training starts.

I'm wondering whether the Deepspeed config is effectively deprecated in qlora-pipe, so that one should not use it? And how should I correctly set bf16 in the code? Thank you so much in advance!
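For reference, by "bf16 in the Deepspeed config" I mean a section roughly like this in the Deepspeed JSON (a sketch, not my exact file; the other keys shown are just illustrative, and I've omitted the offload settings):

```json
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 1,
  "bf16": {
    "enabled": true
  }
}
```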

tdrussell commented 2 months ago

I also have been experimenting with full fine tuning of an 8b model. I just pushed some commits which add the adamw_kahan optimizer type. There's a section in the README that now discusses floating point precision as well. For FFT I recommend setting everything to bf16 and using adamw_kahan optimizer.
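Concretely, that means something along these lines in the .toml config (a sketch only: apart from model_weight_dtype and the adamw_kahan optimizer type, the field names and values here are placeholders, so check the example configs in the repo for the exact schema):

```toml
# Sketch: field names other than model_weight_dtype and the adamw_kahan
# optimizer type are placeholders; see the example configs in the repo.
model_weight_dtype = 'bfloat16'

[optimizer]
type = 'adamw_kahan'
lr = 5e-6
beta1 = 0.9
beta2 = 0.99
weight_decay = 0.01
```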

It is expected that Deepspeed's bf16 mode uses more memory. I think it wraps the optimizer, and does master weights + gradient accumulation + optimizer states all in fp32. This will use much more memory than full bf16 + Kahan summation in the optimizer. I would not use Deepspeed's bf16 mode unless you have a very large amount of VRAM to spare.

If you are setting model_weight_dtype to bf16, it should not be loading the model in fp32. Can you call the print_model_info() function on the pipeline_model after it is loaded? It will show you the dtype of the model weights. If they really are in fp32 despite setting bf16 in the config, there is some bug or edge case somewhere. One more thing to check: are the model weights on disk fp32? Perhaps it is somehow ignoring the config and just loading it as the dtype the model is stored in.
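If print_model_info() is awkward to call where you are, a quick check like this on the loaded model also works (a sketch; adapt it to wherever the pipeline_model object lives in your run):

```python
from collections import Counter

import torch


def summarize_dtypes(model: torch.nn.Module) -> None:
    # Count parameters and bytes per dtype so an accidental fp32 load is obvious.
    numels, nbytes = Counter(), Counter()
    for _, p in model.named_parameters():
        numels[p.dtype] += p.numel()
        nbytes[p.dtype] += p.numel() * p.element_size()
    for dtype in numels:
        print(f"{dtype}: {numels[dtype] / 1e9:.2f}B params, ~{nbytes[dtype] / 1e9:.1f} GB")


# pipeline_model is the model object after loading.
summarize_dtypes(pipeline_model)
```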

iamhappytoo commented 2 months ago

Hi @tdrussell, Really appreciate the helpful clarifications and explanations you provide! The updated repo also helps a lot to makes things clearer. I used the latest repo, call the print_model_info() and I can confirm the model is loaded in bf16 now:) Thank you! But for 70b model, the 8 x 80G VRAM still gives oom. I found that some people were able to get full param tuning of 70b model work with 960 G or 1120G using deepspeed, do you have an estimate about how much VRAM would qlora-pipe take to full param tuning the 70b model, with all possible optimizations applied? And I guess qlora-pipe can be directly migrated to support multi-node training, through modifying the deepspeed launch script, so long as the cluster environment is set up properly, is my understanding correct? Thank you so much!

tdrussell commented 2 months ago

I think it is expected to OOM on 8x80GB of VRAM. In my understanding, this is the per-parameter state we need:

  1. the model weights
  2. 1st Adam moment
  3. 2nd Adam moment
  4. Kahan summation buffer
  5. Accumulators for gradients

Each of these is bf16, so we have 10 bytes per parameter of fixed state to do FFT with this setup. For a 70b model, the 700 GB required already exceeds the 640 GB you have, and you still need a bit for activations. 960 GB should be enough though. Note that traditional mixed precision would use even more, because it also keeps an fp32 master copy of the weights.
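Spelled out for a 70b model (rough numbers, counting 1 GB as 1e9 bytes):

```python
# Rough arithmetic for the fixed training state listed above (all bf16 = 2 bytes each).
params = 70e9
components = 5                      # weights, 2 Adam moments, Kahan buffer, grad accumulators
state_gb = params * 2 * components / 1e9
print(state_gb)                     # 700.0 GB of fixed state, before activations
print(8 * 80)                       # 640 GB available on an 8x80GB node
```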

There are some changes to the optimizer you could make to try to lower VRAM a bit more.

  1. The optimi library I use for Kahan summation AdamW also has a Lion optimizer. Lion keeps less state and will bring you down to 8 bytes per parameter (see the sketch after this list).
  2. optimi also supports the optimizer accumulation + gradient release technique, but I have not tried it and it will need code changes. This removes the need for the gradient accumulators so saves an additional 2 bytes per parameter.
  3. You could try the GaLore optimizer, which will save a very large amount of VRAM, but it is not exactly equivalent to FFT.
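For option 1, swapping in Lion would look roughly like this (a sketch; I'm assuming optimi exposes a Lion class and the same kahan_sum flag as its AdamW, so check its docs for the exact signature):

```python
# Sketch: replace the Kahan AdamW optimizer with Lion from the optimi library.
# Lion keeps a single momentum buffer, so the fixed state drops to
# weights + momentum + Kahan buffer + gradient accumulators = 8 bytes/param in bf16.
from optimi import Lion  # assumption: optimi provides Lion alongside AdamW

optimizer = Lion(
    pipeline_model.parameters(),
    lr=1e-6,            # Lion usually wants a much smaller lr than AdamW
    weight_decay=0.01,
    kahan_sum=True,     # assumption: same Kahan summation flag as optimi's AdamW
)
```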

Regarding multi-node training, I have developed everything with the assumption that it is only 1 node, and it is untested on more than 1. But in principle, I think it should work, assuming Deepspeed pipeline parallelism supports it. I would not be surprised if there are some places in the code that would break, and need fixes.

EDIT: One thing worth trying first, just as a check: you can use the normal AdamW optimizer instead of Kahan AdamW, which drops the Kahan buffer and immediately saves 2 bytes per parameter. There will almost certainly be significant quality degradation in how well the model learns the dataset, so in practice it is not usable, but the state should now total 560 GB, so you can at least verify that it fits on the 8x80GB machine.

iamhappytoo commented 2 months ago

Hi @tdrussell, thank you so much for your detailed answer! I tried the Lion and normal AdamW optimizers, but both still gave an OOM. Then I moved on to multi-node training with Deepspeed (changing --num_gpus xx to --hostfile hostfile, and configuring ssh so that pdsh runs successfully between two multi-GPU nodes). In multi-node mode, deepspeed.init_distributed() runs successfully, but zero_first() in dataset_utils.py starts throwing an error. I tried to reproduce it with simplified code, and it comes down to this: any call that requires inter-node communication, such as barrier() or torch.distributed.all_reduce(), throws the same error. The error kills the subprocesses with return code -11, followed by ssh exiting with exit code 245. I am not sure whether this is related to how accelerate is used in a multi-node environment. Do you have any thoughts or suggestions about this error? Any thoughts from you are much appreciated!
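For reference, the simplified reproduction is essentially this (a sketch, not my exact script), launched with deepspeed --hostfile hostfile on the two nodes:

```python
# Minimal reproduction sketch: any collective that has to cross nodes kills the
# worker with return code -11; the same script runs fine on a single node.
import os

import deepspeed
import torch
import torch.distributed as dist

deepspeed.init_distributed()                 # this step succeeds on both nodes
local_rank = int(os.environ["LOCAL_RANK"])   # set by the deepspeed launcher
torch.cuda.set_device(local_rank)

dist.barrier()                               # dies here (SIGSEGV, return code -11)

x = torch.ones(1, device="cuda")
dist.all_reduce(x)                           # same failure if the barrier is removed
print(f"rank {dist.get_rank()}: {x.item()}")
```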