OK, @samyam helped me to figure out ZeRO-3 - getting a 3.5x larger BS than with zero2. The key was to lower `"sub_group_size": 1e9` from `1e14`.
So I'm able to train t5-11b on a single A100-SXM4-40GB with seq len 1024 and BS=14 with deepspeed ZeRO-3:
export BS=14; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 deepspeed --num_gpus=1 \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-11b --output_dir output_dir \
--adam_eps 1e-06 --evaluation_strategy=steps --do_train --label_smoothing 0.1 --learning_rate 3e-5 \
--logging_first_step --logging_steps 500 --max_source_length 1024 --max_target_length 1024 \
--num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size $BS \
--predict_with_generate --sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 \
--dataset_config "ro-en" --source_prefix "translate English to Romanian: " --val_max_target_length \
128 --warmup_steps 50 --max_train_samples 2000 --max_eval_samples 50 --deepspeed \
tests/deepspeed/ds_config_zero3.json --fp16
Everything else is the same as in the zero-2 post above, and the config file also comes from transformers @ 61c5063, but `ds_config_zero3.json` needs to be changed as shown above.
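For orientation, here is roughly what the relevant `zero_optimization` section of `ds_config_zero3.json` looks like after that change - a sketch only; the keys around it follow the stock config of that era and may differ in your transformers version, the one line that changed is `sub_group_size`:

```json
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu", "pin_memory": true },
    "offload_param": { "device": "cpu", "pin_memory": true },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9
  }
}
```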
@stas00 could you confirm your torch / deepspeed / apex / transformers versions?
@stas00 Thanks so much. May I also ask why you used LR = 3e-5 when the HF page itself notes:
> T5 models need a slightly higher learning rate than the default one set in the Trainer when using the AdamW optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
I used LR = 1e-3 previously without DeepSpeed and it worked perfectly. I am doing generation, but now when using DeepSpeed the loss seems noisy. Is there anything you recommend?
{'loss': 5.4677, 'learning_rate': 0.0, 'epoch': 0.02}
{'loss': 0.9166, 'learning_rate': 0.0, 'epoch': 0.03}
{'loss': 0.6483, 'learning_rate': 0.0, 'epoch': 0.05}
6%|██████████          | 1999/32170 [2:21:21<35:31:11, 4.24s/it]
[2021-11-16 18:02:53,513] [INFO] [logging.py:68:log_dist] [Rank 0] step=2000, skipped=1999, lr=[0.0], mom=[[0.9, 0.999]]
[2021-11-16 18:02:53,513] [INFO] [timer.py:157:stop] 0/2000, SamplesPerSec=5.674303086219585
{'loss': 1.1347, 'learning_rate': 0.0, 'epoch': 0.06}
{'loss': 0.6642, 'learning_rate': 0.0, 'epoch': 0.08}
{'loss': 1.0864, 'learning_rate': 0.0, 'epoch': 0.09}
{'loss': 0.4922, 'learning_rate': 0.0, 'epoch': 0.11}
12%|███████████████████          | 3999/32170 [4:42:30<33:11:13, 4.24s/it]
[2021-11-16 20:24:02,592] [INFO] [logging.py:68:log_dist] [Rank 0] step=4000, skipped=3999, lr=[0.0], mom=[[0.9, 0.999]]
[2021-11-16 20:24:02,593] [INFO] [timer.py:157:stop] 0/4000, SamplesPerSec=5.679144072985121
{'loss': 1.6662, 'learning_rate': 0.0, 'epoch': 0.12}
{'loss': 1.4723, 'learning_rate': 0.0, 'epoch': 0.14}
{'loss': 0.5988, 'learning_rate': 0.0, 'epoch': 0.16}
{'loss': 1.1777, 'learning_rate': 0.0, 'epoch': 0.17}
19%|█████████████████████████████          | 5999/32170 [7:03:38<30:45:21, 4.23s/it]
[2021-11-16 22:45:10,765] [INFO] [logging.py:68:log_dist] [Rank 0] step=6000, skipped=5999, lr=[0.0], mom=[[0.9, 0.999]]
[2021-11-16 22:45:10,765] [INFO] [timer.py:157:stop] 0/6000, SamplesPerSec=5.68092264980687
{'loss': 0.9843, 'learning_rate': 0.0, 'epoch': 0.19}
{'loss': 0.3419, 'learning_rate': 0.0, 'epoch': 0.2}
{'loss': 1.1882, 'learning_rate': 0.0, 'epoch': 0.22}
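(For reference, the AdaFactor setup that the quoted note above alludes to is usually built like this - a sketch following the T5 training tips in the HF docs of that era; when using the Trainer-based example scripts, the `--adafactor` flag is the equivalent switch:)

```python
from transformers import T5ForConditionalGeneration
from transformers.optimization import Adafactor

# constant-LR AdaFactor, as recommended for T5 finetuning in the HF docs
model = T5ForConditionalGeneration.from_pretrained("t5-11b")
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,                 # the constant LR mentioned in this comment
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)
```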
> May I also ask why you used LR = 3e-5 when HF page itself notes
Oh, that was a totally random setting which has no impact on what was being tested (memory usage). I use the same scripts to test many models, and most of the time I only care about the run working and/or fitting into memory when I do that particular type of work. I train them for something like 50 iterations...
Of course, when training for real, I pay attention to the recommended hparam settings. So please don't use any of the lr-like hparams in my examples for fitting memory as a recommendation for real training.
But let's not mix unrelated things in the same thread. If you'd like to discuss a different topic please kindly open a new issue and we can discuss it there.
@stas00 Hopefully this is relevant. I know you had success on an A100 40GB GPU. I am using DeepSpeed on 4 GPUs and I receive an OOM after training for several hours. Any idea as to what I can do here?
{'loss': 6.0737, 'learning_rate': 0.0, 'epoch': 0.02}
{'loss': 0.1926, 'learning_rate': 0.0, 'epoch': 0.04}
{'loss': 0.0399, 'learning_rate': 0.0, 'epoch': 0.06}
8%|█████████████          | 1999/24128 [1:52:11<20:35:01, 3.35s/it]
[2021-11-22 19:51:55,198] [INFO] [logging.py:69:log_dist] [Rank 0] step=2000, skipped=1999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-22 19:51:55,199] [INFO] [timer.py:181:stop] 0/2000, SamplesPerSec=9.546767962244255
{'loss': 0.0749, 'learning_rate': 0.0, 'epoch': 0.08}
{'loss': 0.408, 'learning_rate': 0.0, 'epoch': 0.1}
{'loss': 0.0354, 'learning_rate': 0.0, 'epoch': 0.12}
{'loss': 0.0341, 'learning_rate': 0.0, 'epoch': 0.15}
17%|██████████████████████████          | 3999/24128 [3:43:57<18:47:06, 3.36s/it]
[2021-11-22 21:43:41,103] [INFO] [logging.py:69:log_dist] [Rank 0] step=4000, skipped=3999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-22 21:43:41,103] [INFO] [timer.py:181:stop] 0/4000, SamplesPerSec=9.564911481857864
{'loss': 0.0316, 'learning_rate': 0.0, 'epoch': 0.17}
{'loss': 0.0802, 'learning_rate': 0.0, 'epoch': 0.19}
{'loss': 0.035, 'learning_rate': 0.0, 'epoch': 0.21}
{'loss': 0.1423, 'learning_rate': 0.0, 'epoch': 0.23}
25%|███████████████████████████████████████          | 5999/24128 [5:35:43<16:52:01, 3.35s/it]
[2021-11-22 23:35:26,678] [INFO] [logging.py:69:log_dist] [Rank 0] step=6000, skipped=5999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-22 23:35:26,678] [INFO] [timer.py:181:stop] 0/6000, SamplesPerSec=9.571203445125207
{'loss': 0.1107, 'learning_rate': 0.0, 'epoch': 0.25}
{'loss': 0.0467, 'learning_rate': 0.0, 'epoch': 0.27}
{'loss': 0.0802, 'learning_rate': 0.0, 'epoch': 0.29}
{'loss': 0.0706, 'learning_rate': 0.0, 'epoch': 0.31}
33%|████████████████████████████████████████████████████          | 7999/24128 [7:27:26<15:00:20, 3.35s/it]
[2021-11-23 01:27:10,465] [INFO] [logging.py:69:log_dist] [Rank 0] step=8000, skipped=7999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-23 01:27:10,465] [INFO] [timer.py:181:stop] 0/8000, SamplesPerSec=9.574953735862689
{'loss': 0.22, 'learning_rate': 0.0, 'epoch': 0.33}
{'loss': 0.0967, 'learning_rate': 0.0, 'epoch': 0.35}
{'loss': 0.0716, 'learning_rate': 0.0, 'epoch': 0.37}
{'loss': 0.1111, 'learning_rate': 0.0, 'epoch': 0.39}
41%|█████████████████████████████████████████████████████████████████          | 9999/24128 [9:19:10<13:10:15, 3.36s/it]
[2021-11-23 03:18:53,863] [INFO] [logging.py:69:log_dist] [Rank 0] step=10000, skipped=9999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-23 03:18:53,863] [INFO] [timer.py:181:stop] 0/10000, SamplesPerSec=9.577305314814142
{'loss': 0.2233, 'learning_rate': 0.0, 'epoch': 0.41}
43%|████████████████████████████████████████████████████████████████████          | 10397/24128 [9:41:24<12:47:24, 3.35s/it]
Traceback (most recent call last):
File "./finetune_trainer.py", line 368, in <module>
main()
File "./finetune_trainer.py", line 305, in main
train_result = trainer.train(
File "/home/tuhin.chakr/yes/envs/fairseq/lib/python3.8/site-packages/transformers/trainer.py", line 1316, in train
tr_loss_step = self.training_step(model, inputs)
File "/home/tuhin.chakr/yes/envs/fairseq/lib/python3.8/site-packages/transformers/trainer.py", line 1865, in training_step
loss = self.deepspeed.backward(loss)
File "/home/tuhin.chakr/yes/envs/fairseq/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1708, in backward
self.optimizer.backward(loss)
File "/home/tuhin.chakr/yes/envs/fairseq/lib/python3.8/site-packages/deepspeed/runtime/zero/stage2.py", line 1880, in backward
buf_1 = torch.empty(int(self.reduce_bucket_size),
RuntimeError: CUDA out of memory. Tried to allocate 382.00 MiB (GPU 1; 39.59 GiB total capacity; 36.01 GiB already allocated; 164.94 MiB free; 36.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
My script
export BS=8;
PYTHONPATH=../../src
USE_TF=0
deepspeed --num_gpus=4 ./finetune_trainer.py \
--data_dir /home/tuhin.chakr/gpt3/poetrynew \
--output_dir /local/nlp/temp/poetryT5-11B_new \
--model_name_or_path t5-11b \
--do_train \
--task translation \
--max_source_length 64 \
--max_target_length 64 \
--save_strategy=epoch \
--num_train_epochs 1 \
--per_device_train_batch_size $BS \
--adafactor \
--learning_rate 1e-3 \
--deepspeed /home/tuhin.chakr/gpt3/transformers/tests/deepspeed/ds_config_zero2.json \
--fp16
My config
json = {
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 0
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2.000000e+08,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2.000000e+08,
"contiguous_gradients": true
},
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 8,
"gradient_clipping": 1.0,
"steps_per_print": 2.000000e+03,
"wall_clock_breakdown": false,
"zero_allow_untested_optimizer": true
}
Are you monitoring the memory consumption over the duration of the training - is it borderline OOM from the get-go, or is the memory usage slowly creeping up?
But regardless, you're using only stage-2, and you want stage-3 in this situation: if you're not sharding the params, only 12 out of the 18 bytes per param get sharded. Stage-3 is slower than stage-2 since it has to do more work, but if you can't fit into your GPUs, stage-3 is what you want.
Note that I'm using stage 3 here: https://github.com/huggingface/transformers/issues/9996#issuecomment-856384448
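(As a quick sanity check before switching stages, DeepSpeed also ships ZeRO-3 memory estimators that print the expected per-GPU and CPU requirements for a given model and GPU count. A rough sketch of the usual invocation - the import path can vary across DeepSpeed versions, and instantiating t5-11b just to count its parameters already needs a lot of CPU RAM:)

```python
from transformers import AutoModelForSeq2SeqLM
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live

# load on CPU only to count parameters, then print ZeRO-3 memory estimates
model = AutoModelForSeq2SeqLM.from_pretrained("t5-11b")
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=4, num_nodes=1)
```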
retraining again and this is what my gpu looks like
So this is the state at the beginning of the training, right? Then check it, say, once every 30min and note the differences - if your application is well written then it shouldn't grow after, say, a few hundred iterations, assuming the longest seqlen with the widest batch size has been consumed already.
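(One low-tech way to do that periodic check - an illustration, not something from this thread - is to log GPU memory in the background while the job runs and watch whether `memory.used` keeps creeping up:)

```bash
# appends one CSV line per GPU every 60s
nvidia-smi --query-gpu=timestamp,index,memory.used,memory.total --format=csv -l 60 >> gpu_mem.log &
```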
I'm also noticing that you're using a very old version of our examples - `finetune_trainer.py` hasn't been maintained for a long time - so it'd be hard to debug this situation if there is indeed a gradual memory leak in there. In that case I'd recommend migrating to the recent version of the software.
The snapshot I sent you was after 5 hrs of training. I have 7M samples, and I reduced the max seq len from 128 to 64, so I'm hoping it works this time. Last time it failed at around 40% of training; it's at 22% now.
Yes, if I still can't make it work I will switch to a recent version of the software.
Right, I'm not sure my message is coming across - I'm suggesting that you monitor the memory usage throughout the training.
And that if it OOMs you need to switch to ZeRO-3 and then you should be able to train with a much longer seqlen.
Enabling https://huggingface.co/transformers/performance.html#gradient-checkpointing is another technique to allow for much longer seqlen.
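(With the current example scripts that is usually a single extra flag, assuming your transformers version exposes `gradient_checkpointing` in `TrainingArguments` - check `run_translation.py --help` for your install. A sketch, with the rest of the arguments elided:)

```bash
deepspeed --num_gpus=4 examples/pytorch/translation/run_translation.py \
    ... \
    --gradient_checkpointing \
    --deepspeed tests/deepspeed/ds_config_zero3.json --fp16
```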
@stas00 many thanks for your guidance. I was able to finetune for 1 epoch. I converted the model to fp32, looked at the output, and noticed it's generating garbled text. Now of course this could be because it's only 1 epoch, but I trained on 772,073 samples. Just to be clear, I have a T5-3B model trained on the same data but with different code, and it works perfectly, so I'm assuming my data is fine.
It generated something
**' thou sa wrt e the in thee wast the the of the world, a man of resea the earthe, the in the all the that of**
I am wondering what the reason could be. One thing I find suspicious is that the loss is close to zero, as you can see below. I just wanted to see the generated text as a proof of concept, since it takes around 24 hours to train 1 epoch. Would you recommend finetuning for more epochs, or something else?
{'loss': 6.0737, 'learning_rate': 0.0, 'epoch': 0.02}
{'loss': 0.1926, 'learning_rate': 0.0, 'epoch': 0.04}
{'loss': 0.0399, 'learning_rate': 0.0, 'epoch': 0.06}
8%|█████████████          | 1999/24128 [1:52:11<20:35:01, 3.35s/it]
[2021-11-22 19:51:55,198] [INFO] [logging.py:69:log_dist] [Rank 0] step=2000, skipped=1999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-22 19:51:55,199] [INFO] [timer.py:181:stop] 0/2000, SamplesPerSec=9.546767962244255
{'loss': 0.0749, 'learning_rate': 0.0, 'epoch': 0.08}
{'loss': 0.408, 'learning_rate': 0.0, 'epoch': 0.1}
{'loss': 0.0354, 'learning_rate': 0.0, 'epoch': 0.12}
{'loss': 0.0341, 'learning_rate': 0.0, 'epoch': 0.15}
17%|██████████████████████████          | 3999/24128 [3:43:57<18:47:06, 3.36s/it]
[2021-11-22 21:43:41,103] [INFO] [logging.py:69:log_dist] [Rank 0] step=4000, skipped=3999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-22 21:43:41,103] [INFO] [timer.py:181:stop] 0/4000, SamplesPerSec=9.564911481857864
{'loss': 0.0316, 'learning_rate': 0.0, 'epoch': 0.17}
{'loss': 0.0802, 'learning_rate': 0.0, 'epoch': 0.19}
{'loss': 0.035, 'learning_rate': 0.0, 'epoch': 0.21}
{'loss': 0.1423, 'learning_rate': 0.0, 'epoch': 0.23}
25%|███████████████████████████████████████          | 5999/24128 [5:35:43<16:52:01, 3.35s/it]
[2021-11-22 23:35:26,678] [INFO] [logging.py:69:log_dist] [Rank 0] step=6000, skipped=5999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-22 23:35:26,678] [INFO] [timer.py:181:stop] 0/6000, SamplesPerSec=9.571203445125207
{'loss': 0.1107, 'learning_rate': 0.0, 'epoch': 0.25}
{'loss': 0.0467, 'learning_rate': 0.0, 'epoch': 0.27}
{'loss': 0.0802, 'learning_rate': 0.0, 'epoch': 0.29}
{'loss': 0.0706, 'learning_rate': 0.0, 'epoch': 0.31}
33%|████████████████████████████████████████████████████          | 7999/24128 [7:27:26<15:00:20, 3.35s/it]
[2021-11-23 01:27:10,465] [INFO] [logging.py:69:log_dist] [Rank 0] step=8000, skipped=7999, lr=[0.0, 0.0], mom=[0.0, 0.0]
[2021-11-23 01:27:10,465] [INFO] [timer.py:181:stop] 0/8000, SamplesPerSec=9.574953735862689
{'loss': 0.22, 'learning_rate': 0.0, 'epoch': 0.33}
{'loss': 0.0967, 'learning_rate': 0.0, 'epoch': 0.35}
{'loss': 0.0716, 'learning_rate': 0.0, 'epoch': 0.37}
{'loss': 0.1111, 'learning_rate': 0.0, 'epoch': 0.39}
Why is your `'learning_rate': 0.0`?
@stas00 that's something I don't understand either. As you can see in my script, I specified 1e-3.
My script from transformers repo
export BS=8;
PYTHONPATH=../../src
USE_TF=0
deepspeed --num_gpus=3 ./finetune_trainer.py \
--data_dir /home/tuhin.chakr/gpt3/poetrynew \
--output_dir /local/nlp/temp/poetryT5-11B_new \
--model_name_or_path t5-11b \
--do_train \
--task translation \
--max_source_length 128 \
--max_target_length 128 \
--save_strategy=epoch \
--num_train_epochs 1 \
--per_device_train_batch_size $BS \
--adafactor \
**--learning_rate 1e-3 \**
--deepspeed /home/tuhin.chakr/gpt3/transformers/tests/deepspeed/ds_config_zero2.json \
--fp16
My deepspeed config
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"train_batch_size": 24,
"train_micro_batch_size_per_gpu": 8,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
Someone here reported the same: https://github.com/microsoft/DeepSpeed/issues/1574
I'd be happy to debug this with you, but let's first switch to the current example, which is https://github.com/huggingface/transformers/blob/master/examples/pytorch/translation/run_translation.py - it should be mostly the same with some args renamed - see the README.md for details https://github.com/huggingface/transformers/tree/master/examples/pytorch/translation
e.g. my staple cmd that I use is:
export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus=2 examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06 --evaluation_strategy=steps --do_train --do_eval --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 500 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size $BS --per_device_eval_batch_size $BS --predict_with_generate --sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" --source_prefix "translate English to Romanian: " --val_max_target_length 128 --warmup_steps 50 --max_train_samples 500 --max_eval_samples 50 --deepspeed tests/deepspeed/ds_config_zero3.json --fp16
Additionally, please open a new Issue since this discussion is now taking over this already closed issue, so let's give it a dedicated space. Just don't forget to tag me in the new Issue.
Update on my end: with DeepSpeed 0.3.10 it ran through the night on a full job, successfully training and generating the predictions. Amazing work @stas00 et al.
How did you run inference? Did you get any results?
Could you please tell me where I can find the `ds_config.json` and `finetune_trainer.py`? Thank you!
The examples have been renamed and re-organized since the time of this thread, you can find them all here: https://github.com/huggingface/transformers/tree/main/examples/pytorch
e.g. the translation is now at examples/pytorch/translation/run_translation.py
For deepspeed please see: https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration
@stas00 sorry for such a question, but do I understand correctly that each training example took about 5 seconds? If yes, roughly how long do you think it would take to train T5-11B from scratch on such hardware?
Multiply the iteration time by how many batches you plan to feed the model and you will get the total time needed to train any model - as I wasn't part of the T5 training I don't know what their numbers were.
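To illustrate with the numbers from the log earlier in this thread: at ~4.24 s/iteration, the 32,170 iterations of that run work out to 4.24 × 32,170 ≈ 136,000 s, i.e. about 38 hours for a single finetuning epoch - and pretraining from scratch involves vastly more steps than that.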
Managed to train t5-11b on 1x 40GB gpu w/ Deepspeed (A100-SXM4-40GB)
Thank you, @PeterAJansen for letting me use your hardware!
Thank you, @jeffra and @samyam, for believing that it is indeed possible to train t5-11b on 1x 40GB gpu w/ Deepspeed, and for the support that led me to find a few bugs in the integration.
Sharing details for those who need them.
If you want to try this at home, please make sure you use transformers master, as some bug fixes were just merged in.
Well, it's similar to the t5-3b on 24GB success reported here and here, but this time it's t5-11b on 1x 40GB gpu (or 4x if you want things faster).
As someone asked me before, you need a huge amount of general RAM to use ZeRO-Offload for a huge model:
I was using `/usr/bin/time -v program` to get the peak memory measurement - it's the `Maximum resident set size` entry in the final report.
Question: I don't think `/usr/bin/time` does the right thing for multi-process - I think it only measures the parent process. e.g. with 4x gpus it reported only 102GB RAM, but I clearly saw in `top` that it was around 240GB. If you have an easy way to measure peak memory that takes into account forked processes, I'm all ears.
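(For concreteness, the measurement described above looks roughly like this - the training command is a placeholder:)

```bash
/usr/bin/time -v deepspeed --num_gpus=1 your_training_script.py ... 2>&1 | tee train.log
# then look for this entry in the final report:
#   Maximum resident set size (kbytes): ...
```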
Batch sizes on one gpu:
I'm referring to these batch sizes in `ds_config.json`:
And I tested for 2x and 4x DDP as well, BS=16 OOMed, BS=8 was good so I used that - but could probably squeeze some more.
edit1: later tests show that my test was too short and wasn't getting the CPU Adam optimizer to kick in, as it skips the first 20 or so steps because of the overflow. Once it kicks in it takes more GPU memory, so the practical BS is much smaller - I think around 2 on this setup. So most likely you will need to use `BS=2` for real work, until things get optimized even more.
edit2: things are getting re-shuffled in the tests, so the default `ds_config.json` file has moved in master to a new, hopefully permanent home. It's now at `examples/tests/deepspeed/ds_config.json`, so you will need to adjust the command line to reflect this new location or simply copy it over to where the old one used to be.
Here is the full benchmark:
Checkpointing should allow even bigger batch sizes.