ROCm / TransformerEngine

MI300X FP8 TE.Linear 2x Slower than AMP BF16 F.Linear #73

Open OrenLeung opened 1 month ago

OrenLeung commented 1 month ago

Problem Description

Even on real-world Llama 2 70B training shapes, TE Linear FP8 is 1.5x to 2x slower than AMP BF16 Linear. Do you have any suggestions or magic env flags to improve performance? On H100, TE Linear FP8 is much faster than BF16 AMP Linear.

I have attached a repro and all the relevant versions and installation scripts below.

cc: @hliuca

python3 ./gemm.py 
Benchmark results for Realistic GEMM shapes with warmup=30 and repeats=200
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| Shape (M, N, K)     | bf16 torch.matmul   | bf16 F.linear (with bias)   | bf16 F.linear (with bias & amp)   | TE Linear (FP8 autocast)   |
+=====================+=====================+=============================+===================================+============================+
| (16384, 8192, 1280) | 493.0 TFLOPS        | 491.7 TFLOPS                | 420.2 TFLOPS                      | 206.8 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 1024, 8192) | 546.4 TFLOPS        | 470.0 TFLOPS                | 288.6 TFLOPS                      | 137.1 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 8192, 7168) | 567.0 TFLOPS        | 566.3 TFLOPS                | 504.0 TFLOPS                      | 465.3 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 3584, 8192) | 610.0 TFLOPS        | 545.0 TFLOPS                | 430.1 TFLOPS                      | 325.8 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (8192, 8192, 8192)  | 588.3 TFLOPS        | 504.3 TFLOPS                | 443.2 TFLOPS                      | 372.9 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
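To make the gap in the table above concrete, the following sketch divides the bf16 AMP column by the TE FP8 column for each shape. The numbers are copied from the table; the names `results` and `slowdown` are illustrative only and are not part of the repro script.

```python
# TFLOPS copied from the table above:
# shape -> (bf16 F.linear with bias & amp, TE Linear FP8 autocast)
results = {
    (16384, 8192, 1280): (420.2, 206.8),
    (16384, 1024, 8192): (288.6, 137.1),
    (16384, 8192, 7168): (504.0, 465.3),
    (16384, 3584, 8192): (430.1, 325.8),
    (8192, 8192, 8192): (443.2, 372.9),
}

def slowdown(amp_tflops, fp8_tflops):
    """How many times slower FP8 is than bf16 AMP (a value above 1 means FP8 is slower)."""
    return amp_tflops / fp8_tflops

slowdowns = {shape: slowdown(*tflops) for shape, tflops in results.items()}
for shape, ratio in slowdowns.items():
    print(f"{shape}: TE FP8 is {ratio:.2f}x slower than bf16 AMP")
```

On these numbers the slowdown ranges from about 1.1x to about 2.1x, with the largest gaps on the two shapes that have a small N or a small K.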

Steps to Reproduce

Versions

root@NODENAME:/workspace/llm-train-bench# pip list | grep torch
pytorch-triton-rocm     3.1.0+cf34004b8a
torch                   2.6.0.dev20241012+rocm6.2
torchvision             0.18.0a0+68ba7ec
root@NODENAME:/workspace/llm-train-bench# pip list | grep transformer
transformer_engine      1.8.0.dev0+691dc23

Install Instruction

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt install nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/

RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git
ENV NVTE_FRAMEWORK=pytorch
ENV PYTORCH_ROCM_ARCH=gfx942

RUN cd TransformerEngine && pip install .

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

Reprod Script

import time
import torch
import tabulate
from triton.testing import do_bench
import torch.nn.functional as F
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

torch.manual_seed(0)
repeats = 200
warmup = 30
dtype = torch.bfloat16
device = 'cuda'
verbose = False

shapes = [
    (16384, 8192, 1280), # LLama 70B TP8 Shape
    (16384, 1024, 8192), # LLama 70B TP8 Shape
    (16384, 8192, 7168), # LLama 70B TP8 Shape
    (16384, 3584, 8192), # LLama 70B TP8 Shape
    (8192, 8192, 8192) # Square shape
]

results = []

for (m, n, k) in shapes:
    # Matmul benchmark
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS = 2 * m * n * k
    ms = do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=repeats)
    tflops_matmul = nFLOPS / ms * 1e-9
    time.sleep(3)

    nFLOPS_with_bias = 2 * m * n * k + m * n  # FLOPs for matmul and addition

    # # Linear (with bias) benchmark using F.linear
    weight_with_bias = torch.randn(n, k, device=device, dtype=dtype)
    bias = torch.randn(n, device=device, dtype=dtype)
    input_tensor = torch.randn(m, k, device=device, dtype=dtype)
    ms_linear_with_bias = do_bench(lambda: F.linear(input_tensor, weight_with_bias, bias=bias), warmup=warmup, rep=repeats)
    tflops_linear_with_bias = nFLOPS_with_bias / ms_linear_with_bias * 1e-9
    time.sleep(0.25)

    # # F.linear with autocast bf16 with a, b, and c being fp32
    a = torch.randn(m, k, device=device, dtype=torch.float32)
    b = torch.randn(n, k, device=device, dtype=torch.float32)
    c = torch.randn(n, device=device, dtype=torch.float32)
    with torch.autocast(dtype=dtype, device_type=device):
        ms_autocast = do_bench(lambda: F.linear(a, b, bias=c), warmup=warmup, rep=repeats)
    tflops_autocast = nFLOPS_with_bias / ms_autocast * 1e-9
    time.sleep(0.25)

    # TE Linear (with FP8 autocast) benchmark
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    input_tensor = torch.randn(m, k, device=device)
    linear_layer = te.Linear(k, n, bias=True).to(device)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        ms_te_linear = do_bench(lambda: linear_layer(input_tensor), warmup=warmup, rep=repeats)
    tflops_te_linear = nFLOPS_with_bias / ms_te_linear * 1e-9
    time.sleep(0.25)

    # Append the results to the list
    results.append([
        f"({m}, {n}, {k})",
        f"{tflops_matmul:.1f} TFLOPS",
        f"{tflops_linear_with_bias:.1f} TFLOPS",
        f"{tflops_autocast:.1f} TFLOPS",
        f"{tflops_te_linear:.1f} TFLOPS"
    ])

# Print results using tabulate
headers = [
    "Shape (M, N, K)",
    "bf16 torch.matmul",
    "bf16 F.linear (with bias)",
    "bf16 F.linear (with bias & amp)",
    "TE Linear (FP8 autocast)"
]
print(f"Benchmark results for Realistic GEMM shapes with {warmup=} and {repeats=}")
print(tabulate.tabulate(results, headers=headers, tablefmt="grid"))
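The throughput figures in the script depend on a unit conversion that is easy to misread: `do_bench` reports a time in milliseconds, so FLOPs per millisecond multiplied by 1e-9 gives TFLOPS. A minimal sketch of that arithmetic follows; the helper name `tflops_from_ms` is hypothetical, and the timing in the example is made up for illustration rather than measured.

```python
def tflops_from_ms(m, n, k, ms, with_bias=False):
    """Convert a runtime in milliseconds to TFLOPS for an (M, N, K) GEMM."""
    flops = 2 * m * n * k  # one multiply and one add per inner-product term
    if with_bias:
        flops += m * n  # one add per output element, as counted in the script
    # FLOPs per ms equals 1e3 FLOPs per second; dividing by 1e12 for TFLOPS
    # leaves an overall factor of 1e-9.
    return flops / ms * 1e-9

# Hypothetical timing, not a measurement: an 8192 x 8192 x 8192 GEMM taking 2.0 ms.
print(f"{tflops_from_ms(8192, 8192, 8192, 2.0):.1f} TFLOPS")
```

The bias term adds only M * N operations on top of 2 * M * N * K, so it barely changes the reported TFLOPS for these shapes.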

Operating System

Ubuntu

CPU

AMD CPU

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.2.0

hliuca commented 1 month ago

Hi @OrenLeung I have reported your issues to our dev teams. Thank you.

wenchenvincent commented 1 month ago

Hi @OrenLeung, could you try with NVTE_USE_HIPBLASLT=1 when installing TE? Otherwise, it will use the rocblas path, which does not have GEMM fusion and might not have the latest optimization for fp8 GEMM. We will make the hipblasLt path default soon.

OrenLeung commented 1 month ago

hi @wenchenvincent ,

Unfortunately, using the hipBLASLt backend does not give any better results.

I have moved the TE build out of the Dockerfile so that it is now done inside the container, as you suggested in https://github.com/ROCm/TransformerEngine/issues/74#issuecomment-2414845971

$ python ./reprod.py 
Benchmark results for Realistic GEMM shapes with warmup=30 and repeats=200
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| Shape (M, N, K)     | bf16 torch.matmul   | bf16 F.linear (with bias)   | bf16 F.linear (with bias & amp)   | TE Linear (FP8 autocast)   |
+=====================+=====================+=============================+===================================+============================+
| (16384, 8192, 1280) | 487.2 TFLOPS        | 323.8 TFLOPS                | 294.4 TFLOPS                      | 401.9 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 1024, 8192) | 538.3 TFLOPS        | 490.1 TFLOPS                | 291.0 TFLOPS                      | 196.3 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 8192, 7168) | 557.9 TFLOPS        | 515.8 TFLOPS                | 463.4 TFLOPS                      | 473.3 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 3584, 8192) | 592.0 TFLOPS        | 549.7 TFLOPS                | 450.1 TFLOPS                      | 328.4 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (8192, 8192, 8192)  | 573.2 TFLOPS        | 536.4 TFLOPS                | 463.7 TFLOPS                      | 358.3 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
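As a rough check on how much the TE FP8 column moved between the two runs, this sketch divides the second run's TE numbers (built with NVTE_USE_HIPBLASLT=1) by the first run's. The values are copied from the two tables in this thread; the names `te_fp8` and `gains` are illustrative. Each table is a single run and the bf16 columns also shifted between runs, so part of any difference may be run-to-run noise.

```python
# TE Linear (FP8 autocast) TFLOPS copied from the two tables in this thread:
# shape -> (first run, second run built with NVTE_USE_HIPBLASLT=1)
te_fp8 = {
    (16384, 8192, 1280): (206.8, 401.9),
    (16384, 1024, 8192): (137.1, 196.3),
    (16384, 8192, 7168): (465.3, 473.3),
    (16384, 3584, 8192): (325.8, 328.4),
    (8192, 8192, 8192): (372.9, 358.3),
}

# Ratio above 1 means the second run was faster for that shape.
gains = {shape: second / first for shape, (first, second) in te_fp8.items()}
for shape, gain in gains.items():
    print(f"{shape}: second run is {gain:.2f}x the first run")
```

On these numbers the first two shapes change by roughly 1.9x and 1.4x, while the other three move by less than 5%.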

cc: @hliuca

Docker Image

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt install nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

TE install Instructions (done inside docker container)

cd /workspace
git clone --recursive https://github.com/ROCm/TransformerEngine.git
export NVTE_USE_HIPBLASLT=1
export NVTE_FRAMEWORK=pytorch
export PYTORCH_ROCM_ARCH=gfx942
cd TransformerEngine && pip install .
cd /workspace/llm-train-bench

wenchenvincent commented 1 month ago

@OrenLeung I looked into the traces. There are two major reasons why it is slow:

I had a WIP branch for the optimized cast_transpose (https://github.com/ROCm/TransformerEngine/tree/transpose_cast_opt). With this branch, I was getting some better numbers:

image

It will be further improved when the GEMM tuning is done.

OrenLeung commented 1 month ago

hi @wenchenvincent ,

Thanks for looking into this. It seems that on these GEMMs, fp8 has roughly reached parity with bf16 AMP and is in some cases slightly better.

On H100, using the exact same script, fp8 Linear is 1.5x to 2x faster than bf16 Linear.

image

H100 Dockerfile

FROM nvcr.io/nvidia/pytorch:24.09-py3

RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

wenchenvincent commented 1 month ago

Yes. With the hipblasLt library in that docker container, those fp8 GEMMs were not using the optimal kernels. I filed a tuning request, and we will check which hipblaslt version has the optimal kernels.

hliuca commented 1 month ago

@OrenLeung @wenchenvincent We usually use the latest hipblaslt. The image here uses the latest hipblaslt for fp8: https://hub.docker.com/r/rocm/vllm-dev/tags

OrenLeung commented 1 month ago

hi @hliuca @wenchenvincent , can you please verify on your end that this docker image has improved performance for transformer engine fp8 before I try it on my end?

hliuca commented 1 month ago

@OrenLeung I will run. Thanks.

OrenLeung commented 1 month ago

Thanks @hliuca !

hliuca commented 1 month ago

@OrenLeung I tried and got this,

image

OrenLeung commented 1 month ago

Hi @hliuca ,

Thanks for trying the latest vllm image, but unfortunately these numbers seem worse than @wenchenvincent's numbers, both in absolute TFLOP/s and in fp8 speedup over bf16.

I wonder if using @wenchenvincent's branch on the vllm image would give the best results?

image

OrenLeung commented 1 month ago

Here is my updated spreadsheet with the H100 speedup too. On H100, fp8 provides a ~1.5-1.7x speedup when K is ~8k and a 1.4x speedup when K is small.

image

hliuca commented 1 month ago

@OrenLeung Wen is the TE owner and knows TE better than I do. Please use Wen's data :-)