pytorch / torchtitan

A native PyTorch Library for large model training

enable TritonFusedRMSNorm with local_map annotation #364

Closed: XilunWu closed this 2 weeks ago

XilunWu commented 1 month ago

Stack from ghstack (oldest at bottom):

Summary

This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, giving a 7%-8% performance gain compared to RMSNorm with TP.
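For background, local_map (from torch.distributed._tensor.experimental) is the annotation that lets a plain-tensor function, such as the Triton kernel wrapper here, accept DTensor inputs under Tensor Parallel: arguments are unwrapped to their local shards, the function runs locally, and the output is re-wrapped with declared placements. Below is a minimal sketch; the stand-in fused_rms_norm_fn and the exact placements are illustrative assumptions, not the PR's code:

    import torch
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed._tensor.experimental import local_map

    def fused_rms_norm_fn(x, weight):
        # Stand-in for the Triton-fused kernel: eager RMSNorm on plain local tensors.
        rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
        return (x.float() * rms).type_as(x) * weight

    # local_map unwraps DTensor args to their local shards, runs the function,
    # and re-wraps the result as a DTensor with the declared output placement.
    fused_rms_norm_tp = local_map(
        fused_rms_norm_fn,
        out_placements=[Shard(1)],                  # output sharded on the sequence dim
        in_placements=([Shard(1)], [Replicate()]),  # x sequence-sharded, weight replicated
    )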

Test Plan

Here's the output of running CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh with 4-way Tensor Parallel (tensor_parallel_degree = 4). Detailed settings:

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # positive int N = AC every N layers; 'op' = AC based on ops policy
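
As an aside on the parallelism settings above: with NGPU=8 and tensor_parallel_degree = 4, the data_parallel_degree = -1 is inferred as 8 / 4 = 2, i.e. a 2x4 device mesh. A minimal sketch using PyTorch's init_device_mesh (the "dp"/"tp" dim names are assumptions here):

    from torch.distributed.device_mesh import init_device_mesh

    world_size = 8                       # NGPU
    tp_degree = 4                        # tensor_parallel_degree
    dp_degree = world_size // tp_degree  # data_parallel_degree = -1 means infer

    # 2 data-parallel replicas, each split 4-way with Tensor Parallel
    mesh = init_device_mesh("cuda", (dp_degree, tp_degree), mesh_dim_names=("dp", "tp"))

The results for each norm_type under these settings: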
  1. with norm_type = "rmsnorm"

    [rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
    [rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
    [rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
    [rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
    [rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
    [rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
    [rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
    [rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
    [rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
    [rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
    [rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
    [rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
    [rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
    [rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
    [rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
    [rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
    [rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
    [rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
    [rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
    [rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
    [rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
  2. with norm_type = "fused_rmsnorm"

    [rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
    [rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
    [rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
    [rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
    [rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
    [rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
    [rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
    [rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
    [rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
    [rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
    [rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
    [rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
    [rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
    [rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
    [rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
    [rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
    [rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
    [rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
    [rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
    [rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
    [rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
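
For reference, reading steady state off the two runs above: rmsnorm settles around 600 wps / 11.1% mfu at 65.53GiB, while fused_rmsnorm settles around 637 wps / 11.8% mfu at 65.13GiB, so the fused path is both faster and slightly lighter on memory.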
XilunWu commented 1 month ago

Note: this test requires https://github.com/pytorch/pytorch/pull/126924 to land.

wanchaol commented 2 weeks ago

Your perf benchmark seems to be using batch_size = 1. Can you rerun with batch_size = 4 and update the perf table?

wanchaol commented 2 weeks ago

@XilunWu The WPS for 8B in your summary still doesn't look right. I have the exact same settings, but the WPS on my side is something like this:

[rank0]:2024-06-13 12:53:16,156 - root - INFO - step:  1  loss: 12.2550  memory: 33.35GiB(35.09%)  wps: 542  mfu: 3.17%
[rank0]:2024-06-13 12:53:16,156 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-13 12:53:42,562 - root - INFO - step: 10  loss: 10.7798  memory: 41.18GiB(43.33%)  wps: 2,792  mfu: 16.35%
[rank0]:2024-06-13 12:54:04,065 - root - INFO - step: 20  loss:  9.1087  memory: 41.18GiB(43.33%)  wps: 3,812  mfu: 22.32%
[rank0]:2024-06-13 12:54:25,626 - root - INFO - step: 30  loss:  7.9951  memory: 41.18GiB(43.33%)  wps: 3,802  mfu: 22.27%