XilunWu closed this 2 weeks ago
note: this test requires https://github.com/pytorch/pytorch/pull/126924 to land
Your perf benchmark seems to be using batch_size = 1; can you rerun with batch_size = 4 and update the perf table?
@XilunWu The WPS for 8B in your summary still doesn't look right. I have the exact same settings, but the WPS on my side is something like this:
[rank0]:2024-06-13 12:53:16,156 - root - INFO - step: 1 loss: 12.2550 memory: 33.35GiB(35.09%) wps: 542 mfu: 3.17%
[rank0]:2024-06-13 12:53:16,156 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-13 12:53:42,562 - root - INFO - step: 10 loss: 10.7798 memory: 41.18GiB(43.33%) wps: 2,792 mfu: 16.35%
[rank0]:2024-06-13 12:54:04,065 - root - INFO - step: 20 loss: 9.1087 memory: 41.18GiB(43.33%) wps: 3,812 mfu: 22.32%
[rank0]:2024-06-13 12:54:25,626 - root - INFO - step: 30 loss: 7.9951 memory: 41.18GiB(43.33%) wps: 3,802 mfu: 22.27%
Stack from ghstack (oldest at bottom):
Summary
This PR enables using TritonFusedRMSNorm with Tensor Parallel, with a 7%-8% performance gain compared to RMSNorm with TP.
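For context on why the fused kernel needs to compose with TP: in a typical torchtitan-style TP plan, the attention/MLP projections are sharded across the TP mesh while the norm layers run on sequence-sharded DTensor activations, so TritonFusedRMSNorm has to accept DTensor inputs. The sketch below illustrates that wiring with PyTorch's torch.distributed.tensor.parallel API; it is not the PR's actual parallelization code, and the module names and plan entries are assumptions for illustration.

```python
# Minimal sketch (not the PR's code) of composing norm layers with Tensor Parallel.
# Assumes torchrun has initialized distributed and the block has the named submodules.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (4,))  # tensor_parallel_degree = 4


def apply_tp_to_block(block: torch.nn.Module) -> torch.nn.Module:
    plan = {
        # Norm layers (RMSNorm or TritonFusedRMSNorm) see sequence-sharded DTensor
        # activations under this style of plan, which is why the fused kernel must
        # handle DTensor inputs.
        "attention_norm": SequenceParallel(),
        "ffn_norm": SequenceParallel(),
        # Linear projections are sharded across the TP mesh.
        "attention.wq": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
    }
    return parallelize_module(block, tp_mesh, plan)
```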
Test Plan
Here's the output of running
CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh
using 4-way Tensor Parallel (tensor_parallel_degree = 4). Detailed settings:
with norm_type = "rmsnorm"
with norm_type = "fused_rmsnorm"