OrenLeung opened this issue 1 month ago
I will report this. Thanks Oren.
Thanks @hliuca ,
For further context, on MI300X BF16 with torch.compile nightly, I get the following preliminary results:
In the reprod script the default is batch size 10; I can confirm that batch size 12 also causes a gpucore dump.
Interestingly, at batch size 2 I do not get a gpucore dump, but at that small a batch size the preliminary TFLOP/s/GPU is 491.22, which is 6% slower than BF16 at batch size 12.
python ./train_fsdp_llama_70_reprod.py --bsz=2
python ./train_fsdp_llama_70_reprod.py --bsz=4
python ./train_fsdp_llama_70_reprod.py
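For readers who don't have the reprod script handy, the shape of what it exercises is roughly: a small Llama3-70B-style proxy wrapped in FSDP with BF16 mixed precision and torch.compile, driven by a `--bsz` flag. The sketch below is hypothetical (placeholder model builder, dimensions, and training loop; it is not the actual `train_fsdp_llama_70_reprod.py`):

```python
import argparse

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def build_proxy_model() -> torch.nn.Module:
    # Placeholder for the 4-layer Llama3-70B proxy; the real script builds
    # transformer blocks with Llama3-70B shapes here.
    return torch.nn.Sequential(*[torch.nn.Linear(8192, 8192, bias=False) for _ in range(4)])


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--bsz", type=int, default=10)  # default batch size discussed in this issue
    args = parser.parse_args()

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = FSDP(
        build_proxy_model().cuda(),
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
        use_orig_params=True,  # needed when combining FSDP with torch.compile
    )
    model = torch.compile(model)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(args.bsz, 8192, device="cuda", dtype=torch.bfloat16)
    for _ in range(10):
        loss = model(x).float().pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)


if __name__ == "__main__":
    main()
```

The sketch assumes a torchrun-style launch with one process per GPU; the actual script's launcher and proxy-model definition may differ.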
@OrenLeung This issue was due to our dev branch not yet having all the recent DDP and FSDP optimizations from NVTE. We have a PR in review that should be merged soon and should resolve this issue (https://github.com/ROCm/TransformerEngine/pull/66). Here are the numbers I got with this PR:
- 8xMI300X FP8 TE batch size 2: 572 TFLOP/s/GPU
- 8xMI300X FP8 TE batch size 10: 696 TFLOP/s/GPU
hi @wenchenvincent ,
Thanks for looking into this! Do you have an estimated ETA on when #66 will be merged? Since this is such a big PR, I will probably have to wait till it hits the main branch before I re-test. I will probably also wait for #69 and the transpose_cast_opt branch to merge.
I was also wondering which Dockerfile image you are using as the base image to obtain these results? And is this base image publicly accessible?
From your results, it does seem like your fp8 has better results than MI300X bf16. We estimate that the TCO of MI300X is 78% of an H100, so to get competitive perf-per-$ results vs H100, MI300X fp8 will probably need to hit 742.2 TFLOP/s/GPU (i.e. 78% of the 951.61 TFLOP/s/GPU we measured for H100 TE fp8).
Is there other PRs or thoughts you have that would potentially help improve performance of mi300x te fp8?
cc: @hliuca
Here are my preliminary numbers on this gh issue's model (llama3 70B 4 Layer Proxy):
- 8xMI300X BF16 batch size 8: 508 TFLOP/s/GPU
- 8xMI300X BF16 batch size 10: 512.64 TFLOP/s/GPU
- 8xMI300X BF16 batch size 12: 518.19 TFLOP/s/GPU
- 8xMI300X BF16 batch size 14: OOM
- 8xH100 BF16 batch size 2: 649.02 TFLOP/s/GPU
- 8xH100 BF16 batch size 4: 687.13 TFLOP/s/GPU
- 8xH100 TE FP8 batch size 2: 951.61 TFLOP/s/GPU
- 8xH100 TE FP8 batch size 4: 759.99 TFLOP/s/GPU
@OrenLeung #66 only needs a few minor changes; the bottleneck for merging it has been our CI capacity... But I expect it will be merged this week.
I was using the same docker image that you used for producing the numbers.
I haven't gotten a chance to dump the traces of this model run yet, but I suspect that it might also suffer from the fp8 cast transpose issue, and some fp8 GEMMs might not be tuned yet. So the fp8 cast transpose optimization and fp8 GEMM tuning would potentially further improve the performance.
Furthermore, here are the preliminary H200 numbers. To be competitive with H200 on a perf-per-TCO basis, AMD needs to be at 910 TFLOP/s/GPU.
Thank you Oren for providing H200 data. These data are very valuable and helpful. Our TE team and other teams are actively working on all the issues you have filed.
hi @hliuca ,
I am glad we were able to provide an optimization goal.
Please note that all of the H100 & H200 numbers we shared are preliminary and will probably improve too as I do tuning on them.
Also please note that we are benchmarking & evaluating AMD/Nvidia on other real-world transformer models and real-world GEMM training shapes that we have not shared with Nvidia or AMD, to ensure that the patches made to pytorch, te, hipblaslt, etc. are generalizable.
Yes @OrenLeung, totally understood. Thank you for driving us to do a better job.
After https://github.com/ROCm/TransformerEngine/pull/66 merged to main, I now get a preliminary number of 716.97 TFLOP/s/GPU on my internal codebase.
After 32 warmup steps: Mean TFLOP/s: 716.97, Mean MFU: 27.47%
Great work, @wenchenvincent!
I assume once the Triton transpose cast fused op & v3 CK attn merge, it will be closer to H100's fp8 951.61 TFLOP/s/GPU.
@OrenLeung Thank you!
@wangye805 had run this model on a different machine and he was getting 747 TFLOP/s. We're investigating why that system gives better performance and hope to make it reproducible.
Yeah, the Triton cast transpose should give further improvement. And fp8 GEMM tuning in the hipBLASLt library and CK FA v3 should give more improvements. But for the latter two, we will need to check the timeline internally.
@wenchenvincent interesting that a different machine gives a different TFLOP/s.
Note that before step 16, the TFLOP/s in the reprod script usually fluctuate (as it warms up and does grad accum every 8 steps).
In my internal codebase, I usually do a warmup of 32 steps and then take the mean over 50 steps to get an accurate measurement of what the realistic TFLOP/s would be.
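For reference, that methodology is roughly the following sketch (hypothetical `train_step` and `flops_per_step` placeholders; this is not code from the reprod script or my internal codebase):

```python
import time

import torch

WARMUP_STEPS = 32    # discard the noisy early steps (compile, allocator, grad-accum warmup)
MEASURE_STEPS = 50   # then average over a stable window


def measure_mean_tflops(train_step, flops_per_step: float, device: str = "cuda") -> float:
    """Run `train_step` repeatedly and report mean TFLOP/s over the measured window.

    `train_step` (one full training step) and `flops_per_step` (model FLOPs for
    that step) are placeholders supplied by the caller.
    """
    for _ in range(WARMUP_STEPS):
        train_step()
    torch.cuda.synchronize(device)

    samples = []
    for _ in range(MEASURE_STEPS):
        start = time.perf_counter()
        train_step()
        torch.cuda.synchronize(device)  # make sure the GPU work actually finished
        samples.append(flops_per_step / (time.perf_counter() - start) / 1e12)

    return sum(samples) / len(samples)
```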
It could be that the other machine has a newer version of the kernel driver. And there is some system config tuning that might impact performance as well: https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/system.html
@wenchenvincent that is quite interesting, though from my understanding most of those knobs in the system tuning guide don't really affect text-only transformer-based models much, since this class of models has very small DtoH and HtoD transfers and doesn't really use the CPU much. So tuning NUMA (NPS1, NPS4, etc.) doesn't really affect the performance.
I can see how those knobs would affect CPU-dataloader-heavy models with heavy HtoD transfers, like image or video.
@OrenLeung Those knobs are for general MI300X system tuning. The most relevant knob for the GPU would be this one: https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html#deterministic-clock Sometimes, using the default frequency of 2100MHz for some workloads can trigger a PCC (Peak Current Control) event, lowering the attainable GPU frequency.
Unfortunately, the machine that produced the better perf has been down in the past two days for maintenance and upgrade. Once it is up, we will continue to investigate why it could produce better numbers.
@OrenLeung Also, I think I might have forgotten to mention that we can use autotuning in TE to select the best performing kernels from hipBLASLt for a specific GEMM size (if there are multiple candidate kernels for that size): https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#gemm-tuning-with-hipblaslt
The perf number that I got was without autotuning though. Once we get the machine back up, we will try with autotuning to see how much we can get.
@wenchenvincent nice! I also saw that there is an autotuning storage PR; what is the timeline for that? That way we won't need to autotune on every run and can just cache the optimal GEMM selection.
@OrenLeung The PR is under review and we're looking to merge it end of this week or early next week.
@OrenLeung We have the optimized cast transpose Triton kernel merged in. And with that, I got the following improvement:
8xMI300X FP8 TE batch size 10: 701 TFLOP/s -> 751.88 TFLOP/s
One of my colleagues got a better number, around 795 TFLOP/s, with different machines and different dockers. I will check to see if I can reproduce his numbers.
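For context on why that kernel helps: TE's FP8 GEMMs need both an FP8-cast copy of a tensor and an FP8-cast copy of its transpose, and producing those with two separate ops reads the input twice. Below is an unfused PyTorch reference of the computation (a sketch assuming torch's native float8_e4m3fn dtype; the merged kernel is the Triton implementation mentioned above, which writes both outputs in one pass):

```python
import torch

def cast_transpose_reference(x: torch.Tensor, scale: torch.Tensor):
    """Unfused reference for what a fused cast+transpose kernel produces:
    an FP8 copy of `x` (fed to the forward GEMM) and an FP8 copy of its
    transpose (fed to the weight-gradient GEMM).
    """
    x_scaled = x.float() * scale                                  # apply the FP8 scaling factor
    x_fp8 = x_scaled.to(torch.float8_e4m3fn)                      # cast for the non-transposed input
    xt_fp8 = x_scaled.t().contiguous().to(torch.float8_e4m3fn)    # second pass for the transposed copy
    return x_fp8, xt_fp8
```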
hi @wenchenvincent, thanks! Can you send over the Dockerfile?
Hi @OrenLeung
Please find a Dockerfile attached. I am working with the dev teams to provide a final Dockerfile in the next few days. Meanwhile, if you like, you may try the following Dockerfile, which provides nice perf. Thank you.
Problem Description
On the Llama3 70B proxy model, training stalls & gpucore dumps. The gpucore dumps are 41 GByte per GPU, so I am unable to send them. It is probably easier for y'all to reproduce this error on your end to get the gpucore dump.
I have verified on H100 that TE fp8 for the Llama3 70B FSDP 4-layer proxy model trains perfectly fine, with a 38% TFLOP/s/GPU increase compared to bf16 torch.compile.
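For context on what "TE fp8" means on the PyTorch side, enabling FP8 with TransformerEngine looks roughly like the snippet below. This is a minimal sketch with hypothetical layer sizes and recipe settings, not the actual reprod script:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Hypothetical sizes for illustration; the real proxy model uses Llama3-70B shapes.
layer = te.Linear(8192, 8192, bias=False, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(16, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Delayed-scaling FP8 recipe: E4M3 for forward tensors, E5M2 for gradients (HYBRID).
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

# Only the forward pass runs under fp8_autocast; backward reuses the recorded scales.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)

out.float().sum().backward()
```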
cc: @hliuca
Operating System
Ubuntu
CPU
AMD CPU
GPU
MI300X
ROCm Version
ROCm 6.2.0
ROCm Component
No response
Steps to Reproduce
Docker Image
TE install Instructions (done inside docker container)
Reprod Script
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response