OrenLeung opened this issue 1 month ago
@OrenLeung This issue is due to our dev branch not yet having all the recent DDP and FSDP optimizations from NVTE. We have a PR in review that should be merged soon and should resolve this issue (https://github.com/ROCm/TransformerEngine/pull/66). Here are the numbers that I got with this PR:

8xMI300X FSDP TE FP8 (batch size 2): 442 TFLOP/s
8xMI300X FSDP TE FP8 (batch size 4): 474 TFLOP/s
Thanks @wenchenvincent for looking into this. This is quite competitive with H100 on a perf-per-TCO basis, since MI300X TCO is 78% of an H100's. But unfortunately it is not competitive with H200. Are there any other PRs in the pipeline that would help?
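To make the perf-per-TCO framing concrete, here is the normalization I have in mind (a minimal sketch in Python; the only input is the 78% TCO ratio quoted above, and the equivalent-H100 figure it prints is just that arithmetic, not a measured number):

```python
# Normalize MI300X throughput by relative TCO, assuming MI300X TCO = 78% of H100 TCO.
MI300X_TCO_RELATIVE_TO_H100 = 0.78

def h100_cost_equivalent_tflops(mi300x_tflops: float) -> float:
    """TFLOP/s an H100-cost system would need to deliver to match MI300X on perf per TCO."""
    return mi300x_tflops / MI300X_TCO_RELATIVE_TO_H100

# 474 TFLOP/s on MI300X is worth roughly 608 TFLOP/s on an H100-cost-equivalent basis.
print(h100_cost_equivalent_tflops(474.0))
```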
Here are my results for the full Llama3 8B model:
Full response in the Llama3 70B proxy GH issue: https://github.com/ROCm/TransformerEngine/issues/78#issuecomment-2418538437
cc: @hliuca
Thank you again @OrenLeung. We like your data, and it will be a great reference for our future optimization goals. We will see if we can pass H200 using MI300X :-)
Hi @hliuca,
I am glad we were able to provide an optimization goal.
Please note that all of the H100 & H200 numbers we shared are preliminary and will probably improve as I do tuning on them.
Also please note that we are benchmarking & evaluating AMD/Nvidia on other real-world transformer models and real-world GEMM training shapes that we have not shared with Nvidia or AMD, to ensure that the patches made to pytorch, te, hipblaslt, etc. are generalizable.
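For reference, the per-shape GEMM numbers come from straightforward timing of the matmul; a minimal sketch of that kind of measurement (the shape at the bottom is a made-up example, not one of the undisclosed shapes, and the real harness is not shown here):

```python
# Time a single GEMM training shape and report achieved TFLOP/s. Illustrative only.
import torch

def gemm_tflops(m: int, n: int, k: int, dtype=torch.bfloat16, iters: int = 50) -> float:
    a = torch.randn(m, k, device="cuda", dtype=dtype)
    b = torch.randn(k, n, device="cuda", dtype=dtype)
    for _ in range(10):  # warmup
        a @ b
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        a @ b
    end.record()
    torch.cuda.synchronize()
    seconds_per_gemm = start.elapsed_time(end) / 1e3 / iters
    return 2 * m * n * k / seconds_per_gemm / 1e12  # a GEMM does 2*M*N*K FLOPs

# Hypothetical example shape:
print(gemm_tflops(16384, 4096, 4096))
```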
I now get a preliminary number of 464 TFLOP/s/GPU (batch size 4) after #66 got merged to main, on our internal codebase for this model.
```
After 32 Warmup:
Mean TFLOP/s: 464.33
Mean MFU: 17.79%
```
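For anyone reading along, MFU here is just achieved TFLOP/s over the device's peak FP8 throughput. A minimal sketch of that bookkeeping (the ~2615 TFLOP/s dense FP8 peak for MI300X is my assumption from the public spec sheet, not something taken from the training script):

```python
# MFU = achieved TFLOP/s / peak TFLOP/s, assuming MI300X dense FP8 peak of ~2614.9 TFLOP/s.
MI300X_FP8_PEAK_TFLOPS = 2614.9

def mean_mfu(achieved_tflops_per_gpu: float, peak_tflops: float = MI300X_FP8_PEAK_TFLOPS) -> float:
    return achieved_tflops_per_gpu / peak_tflops

print(f"{mean_mfu(464.33):.2%}")  # ~17.8%, in line with the logged 17.79%
```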
@wenchenvincent & your team, very impressive work to be able to boost perf by 25% in less than 7 days!
It seems like it is competitive with H100 on perf per TCO, but still not on pure performance.
@OrenLeung We have merged in the optimized cast-transpose Triton kernel. With that, I got the following improvement:
8xMI300X FSDP TE FP8 (batch size 4): 475.76 TFLOP/s -> 523.27 TFLOP/s
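For context on what that kernel does: FP8 GEMMs in TE need both the FP8-cast tensor and its FP8-cast transpose (for the forward and backward passes), and the fused kernel produces both in a single pass over the input. Below is a plain PyTorch reference of the operation's semantics, not the Triton kernel itself, with the scaling simplified and the FP8 dtype chosen only for illustration:

```python
import torch

def cast_transpose_reference(x: torch.Tensor, scale: float):
    """Reference semantics of fused cast-transpose on a 2D tensor:
    return FP8(scale * x) and FP8(scale * x).T.

    The optimized Triton kernel fuses both outputs into a single pass over x;
    this reference materializes them separately just to show what is computed.
    """
    x_scaled = x.float() * scale
    x_fp8 = x_scaled.to(torch.float8_e4m3fn)
    x_t_fp8 = x_scaled.t().contiguous().to(torch.float8_e4m3fn)
    return x_fp8, x_t_fp8
```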
Problem Description
Llama3 8B FP8 OOMs at the same batch size as BF16. I need to decrease the batch size to 2 for it to not OOM. At batch size 2, TE FP8 is 21% slower than torch compile BF16 nightly. I have verified that on H100, TE FP8 is able to fit the same batch size as BF16 and results in an 11% increase in perf for this model.
Preliminary Results
Commands:

```
python ./train_fsdp_llama_8b.py --bsz=2
python ./train_fsdp_llama_8b.py --bsz=4
```
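For anyone without the script handy, the pattern those commands exercise is roughly the following (a minimal single-process sketch with a te.Linear standing in for the full Llama3 8B model; this is not the actual train_fsdp_llama_8b.py, and the recipe and shape values are placeholders):

```python
import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process process group so FSDP can be constructed; the real run uses 8 GPUs.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

# Stand-in model: one TE Linear layer instead of the full Llama3 8B.
model = FSDP(te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).cuda())
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

x = torch.randn(2, 8192, 4096, device="cuda", dtype=torch.bfloat16)  # bsz=2-style input
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(x)
out.float().sum().backward()
optim.step()
dist.destroy_process_group()
```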
cc: @hliuca
Operating System
Ubuntu
CPU
AMD CPU
GPU
MI300X
ROCm Version
ROCm 6.2.0
ROCm Component
No response
Steps to Reproduce
Versions
Docker Image
TE install Instructions (done inside docker container)
Reprod
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response