yuvalkirstain / PickScore

MIT License

Question about training time #9

Closed: vishaal27 closed this issue 8 months ago

vishaal27 commented 10 months ago

Hey, thanks for the great work and codebase.

I am trying to train PickScore using the training command you provided, on 8 A100 GPUs (40 GB each). I also made sure that my accelerate and DeepSpeed configs match the ones you provided. However, training is very slow for me, nowhere near the 40 minutes mentioned in the README: it takes around 4.5 hours to train on the whole dataset. Were there additional settings in the accelerate/DeepSpeed configs you used to make training that fast?

yuvalkirstain commented 10 months ago

Hi! Can you please profile a few training steps? That will show whether the slowdown comes from data loading, the forward step, or the backward step.
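One way to get such a breakdown is to time each phase of the loop separately. The sketch below assumes a generic training loop; `PhaseTimer` and its `sync` hook are hypothetical helpers, not part of the PickScore codebase. On GPU, CUDA kernels are launched asynchronously, so the hook should be something like `torch.cuda.synchronize`, otherwise time can be attributed to the wrong phase.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional


class PhaseTimer:
    """Accumulates wall-clock time per named phase (hypothetical helper)."""

    def __init__(self, sync: Optional[Callable[[], None]] = None) -> None:
        # On GPU, pass torch.cuda.synchronize so queued kernels finish
        # before the clock is read; on CPU the default no-op is enough.
        self._sync = sync or (lambda: None)
        self.samples: Dict[str, List[float]] = defaultdict(list)

    @contextmanager
    def phase(self, name: str):
        self._sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            self._sync()
            self.samples[name].append(time.perf_counter() - start)

    def report(self, skip_first: int = 0) -> Dict[str, float]:
        """Mean seconds per phase, optionally skipping warm-up steps."""
        out = {}
        for name, values in self.samples.items():
            kept = values[skip_first:]
            if kept:
                out[name] = sum(kept) / len(kept)
        return out


if __name__ == "__main__":
    timer = PhaseTimer()
    for _ in range(3):
        with timer.phase("load"):
            time.sleep(0.001)  # stands in for fetching the next batch
        with timer.phase("forward"):
            time.sleep(0.002)  # stands in for computing the loss
        with timer.phase("backward"):
            time.sleep(0.004)  # stands in for the backward call
    print(timer.report(skip_first=1))
```

Skipping the first few steps (`skip_first`) keeps one-off warm-up costs such as CUDA context creation and kernel autotuning out of the averages.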

vishaal27 commented 10 months ago

Hey, thanks for getting back to me.

These are the times (in seconds) for each phase of one training step:

Forward Pass: 0.2936246395111084
Backward: 3.9384946823120117
Optimizer: 4.0531158447265625e-06
Load Time: 0.012419700622558594

These use exactly the same config settings as yours (including BF16 mixed precision). Do these look comparable to your run?
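A rough check of these numbers, as a sketch: it assumes the 4000-step total shown in the progress bar of the log below is the whole run, and takes the 40-minute target from the README figure mentioned above.

```python
# Per-step phase times (seconds) as reported above.
times = {
    "forward": 0.2936246395111084,
    "backward": 3.9384946823120117,
    "optimizer": 4.0531158447265625e-06,
    "load": 0.012419700622558594,
}
TOTAL_STEPS = 4000   # assumption: taken from the progress bar in the log
TARGET_MINUTES = 40  # training time quoted from the README

step_time = sum(times.values())
backward_share = times["backward"] / step_time
projected_hours = step_time * TOTAL_STEPS / 3600
target_step_time = TARGET_MINUTES * 60 / TOTAL_STEPS

print(f"step time      : {step_time:.2f} s")         # step time      : 4.24 s
print(f"backward share : {backward_share:.0%}")      # backward share : 93%
print(f"projected run  : {projected_hours:.1f} h")   # projected run  : 4.7 h
print(f"target per step: {target_step_time:.2f} s")  # target per step: 0.60 s
```

Under these assumptions the projection lands in the same range as the ~4.5 hours observed, and almost all of the step is spent in the backward pass.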

I logged the first few optimisation steps to give you a clearer picture (the times are printed manually, and only after the first two steps, so that the numbers are more stable):

[2023-08-12 11:03:01,549][__main__][INFO] - Sync Gradients happened
Steps:   0%| | 1/4000 [00:23<26:02:32, 23.44s/it, ep=0, gst=1, gstl=1.79, lr=0, mem=25.4, st=1, stl=1.79]
[2023-08-12 11:03:05,755][__main__][INFO] - Sync Gradients happened
Steps:   0%| | 2/4000 [00:27<13:28:05, 12.13s/it, ep=0, gst=2, gstl=2.29, lr=3.35e-7, mem=25.4, st=2, stl=2.29]
[2023-08-12 11:03:11,754][__main__][INFO] - Sync Gradients happened
Steps:   0%| | 3/4000 [00:33<10:21:29,  9.33s/it, ep=0, gst=3, gstl=1.6, lr=5.3e-7, mem=25.4, st=3, stl=1.6]
[2023-08-12 11:03:15,973][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:15,975][__main__][INFO] - Times: Forward Pass: 0.3150146007537842,Backward: 3.883713483810425, Optimizer: 5.7220458984375e-06,Load Time: 0.020132780075073242
Steps:   0%| | 4/4000 [00:37<8:06:58,  7.31s/it, ep=0, gst=4, gstl=1.34, lr=6.69e-7, mem=25.4, st=4, stl=1.34]
[2023-08-12 11:03:20,218][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:20,220][__main__][INFO] - Times: Forward Pass: 0.31520891189575195,Backward: 3.9158997535705566, Optimizer: 1.049041748046875e-05,Load Time: 0.014117002487182617
Steps:   0%| | 5/4000 [00:42<6:53:14,  6.21s/it, ep=0, gst=5, gstl=1.81, lr=7.77e-7, mem=25.4, st=5, stl=1.81]
[2023-08-12 11:03:24,467][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:24,470][__main__][INFO] - Times: Forward Pass: 0.3127315044403076,Backward: 3.916308641433716, Optimizer: 5.245208740234375e-06,Load Time: 0.02049875259399414
Steps:   0%|▏ | 6/4000 [00:46<6:08:50,  5.54s/it, ep=0, gst=6, gstl=2.01, lr=8.65e-7, mem=25.4, st=6, stl=2.01]
[2023-08-12 11:03:28,698][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:28,700][__main__][INFO] - Times: Forward Pass: 0.2895190715789795,Backward: 3.9220213890075684, Optimizer: 4.76837158203125e-06,Load Time: 0.01908421516418457
Steps:   0%|▏ | 7/4000 [00:50<5:40:13,  5.11s/it, ep=0, gst=7, gstl=1.96, lr=9.39e-7, mem=25.4, st=7, stl=1.96]
[2023-08-12 11:03:32,943][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:32,945][__main__][INFO] - Times: Forward Pass: 0.2936246395111084,Backward: 3.9384946823120117, Optimizer: 4.0531158447265625e-06,Load Time: 0.012419700622558594
Steps:   0%|▏ | 8/4000 [00:54<5:21:45,  4.84s/it, ep=0, gst=8, gstl=1.33, lr=1e-6, mem=25.4, st=8, stl=1.33]
[2023-08-12 11:03:37,198][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:37,200][__main__][INFO] - Times: Forward Pass: 0.29284191131591797,Backward: 3.950087785720825, Optimizer: 4.76837158203125e-06,Load Time: 0.012250661849975586
Steps:   0%|▏ | 9/4000 [00:59<5:09:35,  4.65s/it, ep=0, gst=9, gstl=1.51, lr=1.06e-6, mem=25.4, st=9, stl=1.51]
[2023-08-12 11:03:41,386][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:41,388][__main__][INFO] - Times: Forward Pass: 0.30242371559143066,Backward: 3.8716259002685547, Optimizer: 4.76837158203125e-06,Load Time: 0.014168977737426758
Steps:   0%|▏ | 10/4000 [01:03<4:59:56,  4.51s/it, ep=0, gst=10, gstl=2.43, lr=1.11e-6, mem=25.4, st=10, stl=2.43]
[2023-08-12 11:03:45,606][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:45,607][__main__][INFO] - Times: Forward Pass: 0.3086836338043213,Backward: 3.897958517074585, Optimizer: 5.7220458984375e-06,Load Time: 0.012292861938476562
Steps:   0%|▎ | 11/4000 [01:07<4:53:56,  4.42s/it, ep=0, gst=11, gstl=1.93, lr=1.16e-6, mem=25.4, st=11, stl=1.93]
[2023-08-12 11:03:49,788][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:49,790][__main__][INFO] - Times: Forward Pass: 0.29230690002441406,Backward: 3.87408185005188, Optimizer: 4.5299530029296875e-06,Load Time: 0.01623678207397461
Steps:   0%|▎ | 12/4000 [01:11<4:49:02,  4.35s/it, ep=0, gst=12, gstl=1.19, lr=1.2e-6, mem=25.4, st=12, stl=1.19]
[2023-08-12 11:03:53,972][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:53,974][__main__][INFO] - Times: Forward Pass: 0.3028097152709961,Backward: 3.8689303398132324, Optimizer: 4.5299530029296875e-06,Load Time: 0.012484312057495117
Steps:   0%|▎ | 13/4000 [01:15<4:45:39,  4.30s/it, ep=0, gst=13, gstl=1.67, lr=1.24e-6, mem=25.4, st=13, stl=1.67]
[2023-08-12 11:03:58,210][__main__][INFO] - Sync Gradients happened
[2023-08-12 11:03:58,212][__main__][INFO] - Times: Forward Pass: 0.29731225967407227,Backward: 3.9296011924743652, Optimizer: 4.5299530029296875e-06,Load Time: 0.011604547500610352
Steps:   0%|▎

Based on the 40-minute figure, each step should ideally take ~0.6 s, but the backward pass alone takes about 6 times that and is clearly the bottleneck. Do you have any idea why this might be happening? I wonder whether the issue is gradient accumulation: right now the gradients seem to be synced, and the model updated, on every step, whereas the DeepSpeed config suggests the update should only happen once every 16 steps (see https://github.com/yuvalkirstain/PickScore/blob/5fa69e812ece53f0f0b7545a0496bd881651e448/trainer/accelerators/deepspeed_accelerator.py#L49C3-L49C3).
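To make the gradient-accumulation question concrete: with accumulation, gradients are all-reduced and the optimizer steps only on every k-th micro-batch. The sketch below shows that schedule, assuming k = 16 as read from the DeepSpeed config linked above; `sync_schedule` is a hypothetical helper, not PickScore code. In Hugging Face Accelerate, the corresponding signal inside an `accelerator.accumulate(model)` block is the `accelerator.sync_gradients` flag (Accelerate may also force a sync at the end of the dataloader).

```python
from typing import List


def sync_schedule(num_micro_batches: int, accumulation_steps: int) -> List[bool]:
    """Per micro-batch, whether gradients are synced and the optimizer steps.

    Hypothetical helper: applies the usual `(i + 1) % k == 0` accumulation rule
    and also syncs on the final micro-batch so a partial cycle is not dropped.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")
    return [
        (i + 1) % accumulation_steps == 0 or i == num_micro_batches - 1
        for i in range(num_micro_batches)
    ]


if __name__ == "__main__":
    schedule = sync_schedule(num_micro_batches=40, accumulation_steps=16)
    sync_points = [i for i, synced in enumerate(schedule) if synced]
    print(sync_points)  # [15, 31, 39]
```

If "Sync Gradients happened" is logged once per progress-bar step, whether each logged step is a single micro-batch or a full 16-micro-batch cycle depends on where the training loop advances the progress bar, which is worth confirming before comparing per-step times.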

yuvalkirstain commented 10 months ago

The backward pass should take about 1-2x as long as the forward pass. Once you solve this, you should get a ~5-10x speed improvement (which matches the ~6x improvement you are hoping for). Try to (1) update torch and (2) run on a single GPU. Once the backward time is 1-2x the forward pass, check what happens with 8 GPUs. In any case, the issue is either your virtual env (e.g., the PyTorch version or installation) or your hardware.
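The arithmetic behind that estimate can be checked against the logged numbers. This is a sketch that assumes the forward and data-loading times stay as measured and only the backward pass shrinks to 1x or 2x the forward time; `expected_speedup` is a hypothetical helper.

```python
def expected_speedup(forward: float, backward: float, load: float, ratio: float) -> float:
    """Current step time divided by the step time if backward took `ratio` x forward."""
    current = forward + backward + load
    healthy = forward + ratio * forward + load
    return current / healthy


# Phase times (seconds) from one of the logged steps.
forward, backward, load = 0.2936246395111084, 3.9384946823120117, 0.012419700622558594

print(f"backward/forward now: {backward / forward:.1f}x")  # backward/forward now: 13.4x
for ratio in (1.0, 2.0):
    speedup = expected_speedup(forward, backward, load, ratio)
    print(f"ratio {ratio:.0f}x -> speedup {speedup:.1f}x")
# ratio 1x -> speedup 7.1x
# ratio 2x -> speedup 4.8x
```

Under these assumptions the gain comes out at roughly 5-7x, consistent with the ~6x gap discussed above.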