pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

[torchtitan][optim] Add fused as an option in train config #355

Closed wz337 closed 3 weeks ago

wz337 commented 1 month ago

With these three PRs landed, we can now support the option fused=True in torchtitan for the Adam and AdamW optimizers.

https://github.com/pytorch/pytorch/pull/125369 https://github.com/pytorch/pytorch/pull/126423 https://github.com/pytorch/pytorch/pull/126750
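
For reference, a minimal sketch of how the option could be threaded from the train config into the optimizer. The helper name and config keys below are illustrative, not torchtitan's actual code:

```python
import torch

def build_optimizer(model: torch.nn.Module, name: str = "AdamW",
                    lr: float = 3e-4, fused: bool = False):
    """Illustrative helper (names are hypothetical, not torchtitan's API).

    fused=True selects the fused CUDA kernel enabled by the PRs above;
    otherwise we keep the multi-tensor (foreach) path as the default.
    """
    optimizer_cls = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW}[name]
    kwargs = {"lr": lr, "fused": True} if fused else {"lr": lr, "foreach": True}
    return optimizer_cls(model.parameters(), **kwargs)

# e.g. with `fused = true` set under the optimizer section of the training .toml:
# optimizer = build_optimizer(model, name=cfg.optimizer.name,
#                             lr=cfg.optimizer.lr, fused=cfg.optimizer.fused)
```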

Ran a performance evaluation on 8 A100 DevGPUs: 1000 steps of 1D DP with the default llama_8b.toml.

Observation: fused = True and fused = False show similar loss curves and memory usage. wps is ~100 higher and mfu is ~1.5-2% higher with fused = True.

Below are the logs for the last 100 steps for both.

**Fused = False**
[rank0]:2024-06-05 12:45:06,227 - root - INFO - Finished dumping traces in 0.37 seconds
[rank0]:2024-06-05 12:45:37,677 - root - INFO - step: 910  loss:  4.6039  memory: 59.48GiB(75.15%)  wps: 2,217  mfu: 41.16%
[rank0]:2024-06-05 12:46:08,843 - root - INFO - step: 920  loss:  4.6427  memory: 59.48GiB(75.15%)  wps: 2,632  mfu: 48.85%
[rank0]:2024-06-05 12:46:40,052 - root - INFO - step: 930  loss:  4.6339  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:47:11,243 - root - INFO - step: 940  loss:  4.5964  memory: 59.48GiB(75.15%)  wps: 2,631  mfu: 48.84%
[rank0]:2024-06-05 12:47:42,655 - root - INFO - step: 950  loss:  4.6477  memory: 59.48GiB(75.15%)  wps: 2,611  mfu: 48.47%
[rank0]:2024-06-05 12:48:13,890 - root - INFO - step: 960  loss:  4.8137  memory: 59.48GiB(75.15%)  wps: 2,626  mfu: 48.75%
[rank0]:2024-06-05 12:48:45,110 - root - INFO - step: 970  loss:  4.5962  memory: 59.48GiB(75.15%)  wps: 2,628  mfu: 48.78%
[rank0]:2024-06-05 12:49:16,333 - root - INFO - step: 980  loss:  4.5450  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:49:47,561 - root - INFO - step: 990  loss:  4.5840  memory: 59.48GiB(75.15%)  wps: 2,627  mfu: 48.76%
[rank0]:2024-06-05 12:50:18,933 - root - INFO - step: 1000  loss:  4.5351  memory: 59.48GiB(75.15%)  wps: 2,615  mfu: 48.53%
[rank0]:2024-06-05 12:50:23,692 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 12:50:24,041 - root - INFO - Finished dumping traces in 0.35 seconds
[rank0]:2024-06-05 12:50:24,422 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 12:50:26,424 - root - INFO - Training completed

**Fused = True**
[rank0]:2024-06-05 14:55:42,894 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 14:56:13,582 - root - INFO - step: 910  loss:  4.6091  memory: 59.48GiB(75.15%)  wps: 2,341  mfu: 43.46%
[rank0]:2024-06-05 14:56:43,765 - root - INFO - step: 920  loss:  4.6468  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:57:13,971 - root - INFO - step: 930  loss:  4.6365  memory: 59.48GiB(75.15%)  wps: 2,715  mfu: 50.40%
[rank0]:2024-06-05 14:57:44,172 - root - INFO - step: 940  loss:  4.6021  memory: 59.48GiB(75.15%)  wps: 2,716  mfu: 50.41%
[rank0]:2024-06-05 14:58:14,353 - root - INFO - step: 950  loss:  4.6522  memory: 59.48GiB(75.15%)  wps: 2,718  mfu: 50.45%
[rank0]:2024-06-05 14:58:44,536 - root - INFO - step: 960  loss:  4.8163  memory: 59.48GiB(75.15%)  wps: 2,717  mfu: 50.44%
[rank0]:2024-06-05 14:59:14,683 - root - INFO - step: 970  loss:  4.6026  memory: 59.48GiB(75.15%)  wps: 2,721  mfu: 50.51%
[rank0]:2024-06-05 14:59:44,840 - root - INFO - step: 980  loss:  4.5491  memory: 59.48GiB(75.15%)  wps: 2,720  mfu: 50.49%
[rank0]:2024-06-05 15:00:15,009 - root - INFO - step: 990  loss:  4.5859  memory: 59.48GiB(75.15%)  wps: 2,719  mfu: 50.47%
[rank0]:2024-06-05 15:00:45,228 - root - INFO - step: 1000  loss:  4.5396  memory: 59.48GiB(75.15%)  wps: 2,714  mfu: 50.38%
[rank0]:2024-06-05 15:00:49,455 - root - INFO - Dumping traces at step 1000
[rank0]:2024-06-05 15:00:49,756 - root - INFO - Finished dumping traces in 0.30 seconds
[rank0]:2024-06-05 15:00:50,336 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2024-06-05 15:00:52,339 - root - INFO - Training completed
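
As a quick cross-check of the summary numbers, averaging the steady-state steps (920-1000) from the two log excerpts above (step 910 immediately follows a trace dump, so it is excluded as an outlier):

```python
from statistics import mean

# wps / mfu values copied from the steady-state log lines above (steps 920-1000)
wps_nonfused = [2632, 2628, 2631, 2611, 2626, 2628, 2627, 2627, 2615]
wps_fused    = [2718, 2715, 2716, 2718, 2717, 2721, 2720, 2719, 2714]
mfu_nonfused = [48.85, 48.78, 48.84, 48.47, 48.75, 48.78, 48.76, 48.76, 48.53]
mfu_fused    = [50.45, 50.40, 50.41, 50.45, 50.44, 50.51, 50.49, 50.47, 50.38]

print(f"wps: {mean(wps_nonfused):.0f} -> {mean(wps_fused):.0f} "
      f"(+{mean(wps_fused) - mean(wps_nonfused):.0f})")
print(f"mfu: {mean(mfu_nonfused):.2f}% -> {mean(mfu_fused):.2f}% "
      f"(+{mean(mfu_fused) - mean(mfu_nonfused):.2f} pp)")
# wps: 2625 -> 2718 (+93)
# mfu: 48.72% -> 50.44% (+1.72 pp)
```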
awgu commented 1 month ago

I am curious if we have any experiments to see the performance difference with fused=True.

wz337 commented 1 month ago

> I am curious if we have any experiments to see the performance difference with fused=True.

Gonna talk to Tianyu to learn how to run the perf experiments on the new 128 GPUs. This PR just adds the option to the config; the default behavior is still foreach=True.

We can totally wait for the result before landing this.

wanchaol commented 1 month ago

Can we add some 8-GPU numbers at least? The 128-GPU run can be done separately.

wz337 commented 3 weeks ago

@wanchaol @awgu Added the performance diff to the summary. I think we are now comfortable offering this option in torchtitan?

weifengpy commented 2 weeks ago

This PR (foreach=True) shortened opt.step from 2000 ms to 200 ms. That's +10% e2e QPS on 16 H100 nodes (16 x 8 GPUs). I might need to refresh the 1D and 2D benchmarks based on this @drisspg

awgu commented 2 weeks ago

@weifengpy foreach=True used to be the default, so perhaps your package was before https://github.com/pytorch/torchtitan/pull/386 landed. Without https://github.com/pytorch/torchtitan/pull/386, the optimizer would fall back to foreach=False when fused=False. 2000 ms for optimizer step sounds like foreach=False.
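
For readers following along, a minimal sketch of the distinction (illustrative, not torchtitan's or #386's actual code): explicitly disabling both paths forces the slow single-tensor loop, while keeping foreach on whenever fused is not requested avoids that fallback.

```python
import torch

# Assumes a CUDA device; a tiny module here, a real model has many parameter tensors.
model = torch.nn.Linear(4096, 4096, device="cuda")

# Explicit foreach=False, fused=False: the single-tensor loop, one small kernel
# launch per parameter per update -- the regime where opt.step can cost ~2000 ms
# on a large model.
slow = torch.optim.AdamW(model.parameters(), lr=1e-4, foreach=False, fused=False)

# Post-#386-style selection: fused kernel when requested, otherwise the
# multi-tensor (foreach) path, so fused=False no longer degrades to the loop.
def make_adamw(params, lr: float, fused: bool) -> torch.optim.AdamW:
    return torch.optim.AdamW(params, lr=lr, fused=fused, foreach=not fused)

fast = make_adamw(model.parameters(), lr=1e-4, fused=False)  # foreach path
```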

weifengpy commented 2 weeks ago

> @weifengpy foreach=True used to be the default, so perhaps your package was before #386 landed. Without #386, the optimizer would fall back to foreach=False when fused=False. 2000 ms for optimizer step sounds like foreach=False.

Ah, got you. I checked the trace and the 2000 ms indeed comes from foreach=False.