ROCm / TransformerEngine

[1xMI300X] GPT-2 XL 1.5B FP8 Training ~30% slower than H100 FP8 #72

Open OrenLeung opened 1 week ago

OrenLeung commented 1 week ago

Problem Description

Hi AMD team,

When I try FP8 training on MI300X, it is extremely slow because of very high CPU overhead, which takes up more than 81% of the time. As you can see from the profile, most of the time is spent on the CPU calling hipFree. On GPT-2 XL 1.5B, throughput is only 22 TFLOP/s. This is 10x slower than MI300X BF16.

For comparison, on H100, FP8 makes GPT-2 XL 1.5B 1.3x faster than BF16, not slower.

The repro script is attached below and can be run using NVTE_FUSED_ATTN_CK=0 python3 ./train.py

image

image

cc: @hliuca

Steps to Reproduce

Versions

root@NODENAME:/workspace/llm-train-bench# pip list | grep torch
pytorch-triton-rocm     3.1.0+cf34004b8a
torch                   2.6.0.dev20241012+rocm6.2
torchvision             0.18.0a0+68ba7ec
root@NODENAME:/workspace/llm-train-bench# pip list | grep transformer
transformer_engine      1.8.0.dev0+691dc23

Install Instructions

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt install nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/

RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git
ENV NVTE_FRAMEWORK=pytorch
ENV PYTORCH_ROCM_ARCH=gfx942

RUN cd TransformerEngine && pip install .

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

Reprod GPT2 XL 1.5B Training

import contextlib

import torch
import torch.nn.functional as F
import torch.nn as nn

from pydantic.dataclasses import dataclass

@dataclass
class GPTConfig:
    n_layers: int    # L
    n_heads: int     # H
    d_embd: int      # E
    max_seq_len: int = 1024
    vocab_size: int  = 50304 # V
    arch_name: str = 'gpt'

    @staticmethod
    def estimate_flops_per_token(model, config):
        # get param count
        N = sum(p.numel() for p in model.parameters())

        # print param count in B
        print(f"Param count: {N/1e9}B")

        head_dim = config['d_embd'] // config['n_heads'] 

        flops_per_token = 6 * N + 12 * config['n_layers'] * config['n_heads'] * head_dim * config['max_seq_len']

        return flops_per_token

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)

        # self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        import transformer_engine.pytorch as te
        self.tsfmr_blks = nn.ModuleList(te.TransformerLayer(
                    d_embd,
                    d_embd * 4,
                    kwargs['n_heads'],
                    layer_number=i+1,
                    # Optional, for speedups
                    fuse_qkv_params=True,
                    attn_input_format='bshd'
                ) 
                for i in range(n_layers)                       
                )

        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

def train(
    gpu_id: int = 0,
    bsz: int = 8,
    grad_acc_steps: int = 8,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(gpu_id)

    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }

    cfg_m = GPTConfig(**cfg_json)
    model = GPT(**cfg_json).to(gpu_id)

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)

    flops_per_token = cfg_m.estimate_flops_per_token(model, cfg_json)
    flops_per_iter = flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2600e12

    model.train()

    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
    fp8_format = Format.HYBRID
    # Reasonable default setting
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    # Note: wrapped ctx in a function because the te.fp8_autocast object cannot be reused as a context for some reason.
    @contextlib.contextmanager
    def ctx():
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                yield

    with ctx():
        for step_idx in range(100):
            # Use bsz (not a hardcoded 8) so the batch matches flops_per_iter above
            input_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            label_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
            loss /= grad_acc_steps
            loss.backward()

            if (step_idx + 1) % grad_acc_steps == 0:  # Assume n_steps % grad_acc_steps == 0
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            print(f'{(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')

if __name__ == '__main__':
    import fire
    fire.Fire(train)
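
The TFLOP/s and MFU figures the script prints come from its FLOPs-per-token estimate, 6 * N + 12 * L * H * head_dim * T. The sketch below reproduces that arithmetic without a GPU. Note the assumptions: `gpt2_param_count` is a hypothetical helper that counts parameters analytically for a standard GPT-2 layout with tied output weights, so it may differ slightly from what `te.TransformerLayer` actually allocates, and the peak of 2600e12 FLOP/s is simply the `flops_promised` value used in the script.

```python
def gpt2_param_count(n_layers, d_embd, vocab_size, max_seq_len):
    """Analytic parameter count for a standard GPT-2 layout (assumption, not
    read from the real model): tied output weights, biases on every linear."""
    embeddings = vocab_size * d_embd + max_seq_len * d_embd
    # Per block: 4*E^2 (attention) + 8*E^2 (MLP) weights, 9*E biases,
    # and 4*E for the two LayerNorms.
    per_layer = 12 * d_embd**2 + 13 * d_embd
    final_norm = 2 * d_embd
    return embeddings + n_layers * per_layer + final_norm


def flops_per_token(n_params, n_layers, n_heads, d_embd, max_seq_len):
    """Same formula as estimate_flops_per_token in the script."""
    head_dim = d_embd // n_heads
    return 6 * n_params + 12 * n_layers * n_heads * head_dim * max_seq_len


def throughput(flops_tok, bsz, seq_len, step_seconds, peak_flops):
    """Return (achieved FLOP/s, MFU) for one training step."""
    flops_per_sec = flops_tok * bsz * seq_len / step_seconds
    return flops_per_sec, flops_per_sec / peak_flops


if __name__ == "__main__":
    n_params = gpt2_param_count(48, 1600, 50304, 1024)
    per_tok = flops_per_token(n_params, 48, 25, 1600, 1024)
    print(f"Param count: {n_params / 1e9:.3f}B")
    print(f"FLOPs per token: {per_tok / 1e9:.2f} GFLOP")
    # Step time is a placeholder; substitute the measured value.
    fps, mfu = throughput(per_tok, bsz=8, seq_len=1024,
                          step_seconds=1.0, peak_flops=2600e12)
    print(f"{fps / 1e12:.2f} TFLOP/s  MFU={mfu:.2%}")
```

Because throughput is inversely proportional to step time, any fixed per-step CPU overhead (such as the hipFree time seen in the profile) lowers the reported TFLOP/s directly.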

Operating System

Ubuntu

CPU

AMD CPU

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.2.0

hliuca commented 1 week ago

Hi @OrenLeung this has been reported. Thank you.

OrenLeung commented 1 week ago

For reference, GPT2-XL 1.5B BF16 with torch.compile (PyTorch nightly) at the same batch size as the repro script gets:

H100 Transformer Engine FP8 gets 1.3x faster than H100 BF16

Even when building with NVTE_USE_HIPBLASLT=1 the reprod script only gets:

This means that AMD FP8 is slower by 20%!!

hliuca commented 1 week ago

Hi @OrenLeung, we have more people working on the issues you reported and we will drive the fixes. Thank you.

wangye805 commented 1 week ago

Hi @OrenLeung, I tried to run the same script on our single-GPU H100 machine but I only got around 100 TFLOP/s: [image] Can you provide the reproduction instructions for H100 (to get 638.4 TFLOP/s)?

Thanks.

LiGuihong commented 1 week ago

@OrenLeung Hi, I tried to reproduce your issue. I got 166 TFLOP/s on MI300. The major difference is that I am using 2.6.0.dev20241014, since I cannot find the 20241012 version. Could you please provide the system configuration on your side? Thanks.

OrenLeung commented 1 week ago

hi @wangye805 ,

It seems like you are on an H100 PCIe 350W card while I am on an H100 SXM 700W card.

638.4 TFLOP/s/GPU was from H200 at batch size=38. I apologize for the error.

I have updated the script below to use batch size 14. Feel free to use a larger batch size if that helps MI300X TFLOP/s/GPU.

To change the batch size, just run python ./reprod.py --bsz=BATCH_SIZE

H100/H200 Dockerfile

FROM nvcr.io/nvidia/pytorch:24.09-py3

RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

Reprod Script

import contextlib

import torch
import torch.nn.functional as F
import torch.nn as nn

from pydantic.dataclasses import dataclass

@dataclass
class GPTConfig:
    n_layers: int    # L
    n_heads: int     # H
    d_embd: int      # E
    max_seq_len: int = 1024
    vocab_size: int  = 50304 # V
    arch_name: str = 'gpt'

    @staticmethod
    def estimate_flops_per_token(model, config):
        # get param count
        N = sum(p.numel() for p in model.parameters())

        # print param count in B
        print(f"Param count: {N/1e9}B")

        head_dim = config['d_embd'] // config['n_heads'] 

        flops_per_token = 6 * N + 12 * config['n_layers'] * config['n_heads'] * head_dim * config['max_seq_len']

        return flops_per_token

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)

        # self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        import transformer_engine.pytorch as te
        self.tsfmr_blks = nn.ModuleList(te.TransformerLayer(
                    d_embd,
                    d_embd * 4,
                    kwargs['n_heads'],
                    layer_number=i+1,
                    # Optional, for speedups
                    fuse_qkv_params=True,
                    attn_input_format='bshd'
                ) 
                for i in range(n_layers)                       
                )

        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

def train(
    gpu_id: int = 0,
    bsz: int = 14,
    grad_acc_steps: int = 8,
):
    torch.manual_seed(3985)
    torch.cuda.set_device(gpu_id)

    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
    }

    cfg_m = GPTConfig(**cfg_json)
    model = GPT(**cfg_json).to(gpu_id)

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)

    flops_per_token = cfg_m.estimate_flops_per_token(model, cfg_json)
    flops_per_iter = flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2600e12

    model.train()

    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
    fp8_format = Format.HYBRID
    # Reasonable default setting
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    # Note: wrapped ctx in a function because the te.fp8_autocast object cannot be reused as a context for some reason.
    @contextlib.contextmanager
    def ctx():
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                yield

    with ctx():
        for step_idx in range(100):
            input_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            label_BT = torch.randint(50304, [bsz, 1024], dtype=torch.int64).to('cuda:0')
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
            loss /= grad_acc_steps
            loss.backward()

            if (step_idx + 1) % grad_acc_steps == 0:  # Assume n_steps % grad_acc_steps == 0
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            print(f'{(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')

if __name__ == '__main__':
    import fire
    fire.Fire(train)

OrenLeung commented 1 week ago

> @OrenLeung Hi, I tried to reproduce your issue. I got 166 TFLOP/s on MI300. The major difference is that I am using 2.6.0.dev20241014, since I cannot find the 20241012 version. Could you please provide the system configuration on your side? Thanks.

Hi, it seems like we have similar results. I am getting 174 TFLOP/s/GPU on MI300X using the recommended hipblaslt backend. Unfortunately, that is still slower than BF16 on MI300X.

The system setup is already provided in this GitHub issue. Is there any specific information that you would like to know?

wangye805 commented 1 week ago

I found a bug in use_fused_attention filtering and sent out PR#81 to fix this issue. With CK fused attention enabled, our MI300X achieves 250 TFLOP/s with hipblaslt, compared to 300 TFLOP/s on H100, when the batch size is set to 8.

We will keep working on other optimizations to match the performance.

OrenLeung commented 1 week ago

hi @wangye805

Thanks for the fix! 250 TFLOP/s is much better now.

Though, it is still about 100 TFLOP/s off from H100 at the same batch size (bsz=8), and about 2.5x slower than H200 at bsz=38.

Here are my preliminary single-GPU H100 700W SXM results:

cc: @hliuca

OrenLeung commented 1 week ago

hi @wangye805 ,

Can we keep this issue open until single-GPU MI300X is able to match H100?

hliuca commented 1 week ago

@OrenLeung I guess if @wangye805 increases bsz, the perf will increase too.

OrenLeung commented 1 week ago

> @OrenLeung I guess if @wangye805 increases bsz, the perf will increase too.

@hliuca

Currently, when both are at batch size 8, MI300X gets 250 TFLOP/s and H100 gets 345.41 TFLOP/s, so about a 30% difference.

But my intuition is that this 30% difference will stay the same as we increase the batch size for both.
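
As a quick check of the "about 30%" figure, the gap can be expressed either as how far MI300X trails H100 or as how much faster H100 is. This is a minimal sketch using only the two throughput numbers quoted above; `relative_gap` is a hypothetical helper name.

```python
def relative_gap(slower, faster):
    """Return (fraction by which `slower` trails `faster`,
    speedup factor of `faster` over `slower`)."""
    return 1.0 - slower / faster, faster / slower


mi300x_tflops = 250.0   # MI300X FP8 at bsz=8, from the thread
h100_tflops = 345.41    # H100 FP8 at bsz=8, from the thread

deficit, speedup = relative_gap(mi300x_tflops, h100_tflops)
print(f"MI300X trails H100 by {deficit:.1%}; H100 is {speedup:.2f}x faster")
# -> MI300X trails H100 by 27.6%; H100 is 1.38x faster
```

The two ways of stating the gap give different percentages (27.6% slower versus 38% faster), so it helps to say which baseline is meant when comparing results at other batch sizes.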

hliuca commented 1 week ago

Hi @OrenLeung, I am not so sure about this. For many LLMs I have seen, as we keep increasing the workload, the gap gets smaller and smaller, and we can eventually pass H100. Of course, that also depends on optimizations. We will keep working on this.

wangye805 commented 1 week ago

@OrenLeung I reopened this issue; it will stay open until we can match H100.

wangye805 commented 1 day ago

With hipblaslt auto-tuning, I can get 370 TFLOP/s with batch size 50.

OrenLeung commented 1 day ago

Thanks @wangye805, can you share the exact env flags, and the values I should set them to, to turn on hipBLASLt auto-tuning?

wangye805 commented 1 day ago

@OrenLeung You can use the following envs to turn on hipblaslt tuning: TE_HIPBLASLT_TUNING_RUN_COUNT=20 TE_HIPBLASLT_TUNING_ALGO_COUNT=1000
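
One way to apply these variables reproducibly is to build the environment in a small launcher and start the repro script as a subprocess. This is only a sketch: the variable names and values are the ones quoted above, but the thread does not say when Transformer Engine reads them, so setting them before the process starts is an assumption made to be safe. `tuning_env` and `launch` are hypothetical helper names, and `./reprod.py --bsz=50` mirrors the script name and batch size mentioned earlier in the thread.

```python
import os
import subprocess
import sys


def tuning_env(base_env=None, run_count=20, algo_count=1000):
    """Copy an environment and add the hipblaslt tuning variables from the thread."""
    env = dict(os.environ if base_env is None else base_env)
    env["TE_HIPBLASLT_TUNING_RUN_COUNT"] = str(run_count)
    env["TE_HIPBLASLT_TUNING_ALGO_COUNT"] = str(algo_count)
    return env


def launch(script="./reprod.py", bsz=50):
    """Run the repro script with tuning enabled (script path is an assumption)."""
    cmd = [sys.executable, script, f"--bsz={bsz}"]
    return subprocess.run(cmd, env=tuning_env(), check=True)
```

Calling `launch()` on the MI300X machine is equivalent to prefixing the python command with the two variables on the shell command line.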

OrenLeung commented 1 day ago

> With hipblaslt auto-tuning, I can get 370 TFLOP/s with batch size 50.

@wangye805 at a giant batch size, H100 can get 490 TFLOP/s and H200 can get 630 TFLOP/s. Good improvement, but it seems there is still a 120 TFLOP/s difference from H100.

OrenLeung commented 1 day ago

> @OrenLeung You can use the following envs to turn on hipblaslt tuning: TE_HIPBLASLT_TUNING_RUN_COUNT=20 TE_HIPBLASLT_TUNING_ALGO_COUNT=1000

Thanks! Will definitely try that. tuning_algo_count=1000 lol, I guess we probably need #67 to store the tuning results?