OrenLeung opened this issue 1 month ago
Hi @OrenLeung I have reported your issues to our dev teams. Thank you.
Hi @OrenLeung, could you try with NVTE_USE_HIPBLASLT=1 when installing TE? Otherwise, it will use the rocblas path, which does not have GEMM fusion and might not have the latest optimizations for fp8 GEMM. We will make the hipBLASLt path the default soon.
Hi @wenchenvincent,
Unfortunately, using the hipBLASLt backend does not give any better results.
I have updated my Dockerfile and TE build instructions so that TE is built inside the container instead of in the Dockerfile,
as you suggested in https://github.com/ROCm/TransformerEngine/issues/74#issuecomment-2414845971
$ python ./reprod.py
Benchmark results for Realistic GEMM shapes with warmup=30 and repeats=200
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| Shape (M, N, K) | bf16 torch.matmul | bf16 F.linear (with bias) | bf16 F.linear (with bias & amp) | TE Linear (FP8 autocast) |
+=====================+=====================+=============================+===================================+============================+
| (16384, 8192, 1280) | 487.2 TFLOPS | 323.8 TFLOPS | 294.4 TFLOPS | 401.9 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 1024, 8192) | 538.3 TFLOPS | 490.1 TFLOPS | 291.0 TFLOPS | 196.3 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 8192, 7168) | 557.9 TFLOPS | 515.8 TFLOPS | 463.4 TFLOPS | 473.3 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 3584, 8192) | 592.0 TFLOPS | 549.7 TFLOPS | 450.1 TFLOPS | 328.4 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (8192, 8192, 8192) | 573.2 TFLOPS | 536.4 TFLOPS | 463.7 TFLOPS | 358.3 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
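For context, the numbers above are a throughput comparison; a minimal sketch of that kind of timing loop is below (my own paraphrase, not the actual reprod.py: the shape, the warmup/repeat counts, and the DelayedScaling recipe are assumptions).

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Time a callable with CUDA events (warmup=30, repeats=200, matching the table header).
def bench_ms(fn, warmup=30, repeats=200):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeats):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeats  # ms per iteration

M, N, K = 16384, 8192, 8192  # one of the shapes from the table
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
fp8_linear = te.Linear(K, N, bias=True, params_dtype=torch.bfloat16).cuda()
recipe = DelayedScaling(fp8_format=Format.HYBRID)  # assumed fp8 recipe

def fp8_fwd():
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        fp8_linear(a)

tflops = lambda ms: 2 * M * N * K / (ms * 1e-3) / 1e12
print(f"bf16 matmul: {tflops(bench_ms(lambda: a @ b.t())):.1f} TFLOPS")
print(f"TE fp8 Linear: {tflops(bench_ms(fp8_fwd)):.1f} TFLOPS")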
cc: @hliuca
FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0
RUN apt-get update && apt-get install -y nano
RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic pybind11
RUN pip3 uninstall -y torch
RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
WORKDIR /workspace/llm-train-bench/
CMD ["/usr/bin/bash"]
cd /workspace
git clone --recursive https://github.com/ROCm/TransformerEngine.git
export NVTE_USE_HIPBLASLT=1
export NVTE_FRAMEWORK=pytorch
export PYTORCH_ROCM_ARCH=gfx942
cd TransformerEngine && pip install .
cd /workspace/llm-train-bench
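After installing, a quick smoke test along these lines (my own sketch, not part of the repo) confirms that TE imports and an fp8 forward/backward pass runs:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Single TE Linear layer exercised under fp8 autocast.
layer = te.Linear(8192, 8192, bias=True, params_dtype=torch.bfloat16).cuda()
x = torch.randn(1024, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=Format.HYBRID)):
    y = layer(x)
y.sum().backward()
print("fp8 forward/backward OK:", tuple(y.shape), tuple(x.grad.shape))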
@OrenLeung I looked into the traces. There are two major reasons why it is slow:
I had a WIP branch for the optimized cast_transpose (https://github.com/ROCm/TransformerEngine/tree/transpose_cast_opt). With this branch, I was getting some better numbers:
It will be further improved when the GEMM tuning is done.
Hi @wenchenvincent,
Thanks for looking into this. With that branch, it seems fp8 on these GEMMs has roughly reached parity with bf16 AMP, and is only slightly better on some shapes.
On H100, using the exact same script, fp8 Linear is 1.5x to 2x faster than bf16 Linear.
FROM nvcr.io/nvidia/pytorch:24.09-py3
RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic
WORKDIR /workspace/llm-train-bench/
CMD ["/usr/bin/bash"]
Yes. Those fp8 GEMMs were not using the optimal kernels with the hipBLASLt library in that Docker container. I filed a tuning request, and we will check which hipBLASLt version has the optimal kernels.
@OrenLeung @wenchenvincent We usually use the latest hipBLASLt. The image here uses the latest hipBLASLt for fp8: https://hub.docker.com/r/rocm/vllm-dev/tags
Hi @hliuca @wenchenvincent, can you please verify on your end that this Docker image has improved Transformer Engine fp8 performance before I try it on my end?
@OrenLeung I will run. Thanks.
Thanks @hliuca !
@OrenLeung I tried and got this,
Hi @hliuca,
Thanks for trying the latest vLLM image, but unfortunately these numbers seem worse than @wenchenvincent's, both in absolute TFLOP/s and in fp8 speedup over bf16.
I wonder if using @wenchenvincent's branch on top of the vLLM image would give the best results?
Here is my updated spreadsheet with the H100 speedups too. On H100, fp8 provides a ~1.5-1.7x speedup when K is ~8k and a ~1.4x speedup when K is small.
@OrenLeung Wen is the TE owner and knows TE better than me. Please use Wen's data :-)
Problem Description
Even on real-world Llama 2 70B training shapes, TE Linear FP8 is 1.5 to 2x slower than AMP BF16 Linear. Do you have any suggestions or magic env flags on how to improve performance? On H100, TE Linear FP8 is much faster than BF16 AMP Linear.
I have attached a repro script & all the relevant versions & installation instructions below.
cc: @hliuca
Steps to Reproduce
Versions
Install Instruction
Reprod Script
Operating System
Ubuntu
CPU
AMD CPU
GPU
AMD Instinct MI300X
ROCm Version
ROCm 6.2.0