ROCm / TransformerEngine


[DDP 8xMI300X] GPT2-1.5B FP8 is 25% slower than BF16 & OOMs on the same batch size #76

Open OrenLeung opened 1 month ago

OrenLeung commented 1 month ago

Problem Description

Even with NVTE_USE_HIPBLASLT=1 and with TE installed inside the container rather than through the Dockerfile, as suggested in https://github.com/ROCm/TransformerEngine/issues/74#issuecomment-2414845971, FP8 is 25% slower than BF16. It also OOMs at the same batch size that BF16 fits. On Nvidia H100 with Transformer Engine, I can usually fit a larger batch size with FP8 than with BF16, and it never OOMs at the same batch size.

The command to run this is python ./train_gpt_ddp_reprod.py, using the reprod script and TE install instructions below.

cc: @hliuca

Preliminary results (TFLOP/s/GPU):

On this model with DDP, H100 saw a 16% increase in TFLOP/s/GPU from using FP8.

Operating System

Ubuntu

CPU

AMD CPU

GPU

MI300X

ROCm Version

ROCm 6.2.0

ROCm Component

No response

Steps to Reproduce

Docker Image

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt-get update && apt-get install -y nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

TE Install Instructions (run inside the Docker container)

cd /workspace
git clone --recursive https://github.com/ROCm/TransformerEngine.git
export NVTE_USE_HIPBLASLT=1
export NVTE_FRAMEWORK=pytorch
export PYTORCH_ROCM_ARCH=gfx942
cd TransformerEngine && pip install .
cd /workspace/llm-train-bench
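Before running the reprod script, it can help to confirm that the shell still has the exports from the install steps and that the package is importable. The following is a minimal sketch, not part of the original report: check_te_setup is a hypothetical helper, and it only inspects the three variables listed above. It cannot tell whether the build actually picked them up.

```python
import importlib.util
import os

# Variable names and values are the ones exported in the install steps above.
EXPECTED_ENV = {
    "NVTE_USE_HIPBLASLT": "1",
    "NVTE_FRAMEWORK": "pytorch",
    "PYTORCH_ROCM_ARCH": "gfx942",
}


def check_te_setup(env=None):
    """Hypothetical helper: return (env mismatches, whether TE can be found)."""
    env = os.environ if env is None else env
    mismatched = {
        name: env.get(name)
        for name, want in EXPECTED_ENV.items()
        if env.get(name) != want
    }
    # find_spec locates the package without importing it (no GPU needed).
    te_found = importlib.util.find_spec("transformer_engine") is not None
    return mismatched, te_found


if __name__ == "__main__":
    mismatched, te_found = check_te_setup()
    print("env mismatches:", mismatched or "none")
    print("transformer_engine importable:", te_found)
```

An empty mismatch dict together with a found package suggests the environment matches the instructions; anything else is worth rechecking before comparing FP8 and BF16 numbers.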

Reprod Script

from tqdm import tqdm
from dataclasses import asdict
from pydantic.dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist

# DDP
import os
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

# FP8 Transformer Engine
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

def dprint(rank, *args, **kwargs):
    if rank == 0:
        print(*args, **kwargs)

class DummyDataset(Dataset):
    def __init__(self, vocab_size, max_seq_len, ds_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.ds_len = ds_len

    def __getitem__(self, idx):
        input_T = torch.randint(self.vocab_size, [self.max_seq_len], dtype=torch.int64)
        # Next-token labels: inputs shifted left by one, with a random final token
        label_T = torch.cat([input_T[1:], torch.randint(self.vocab_size, [1])])
        return input_T, label_T

    def __len__(self):
        return self.ds_len

def create_distributed_data_loader(rank, world_size, bsz, n_steps, cfg_m):
    dataset = DummyDataset(cfg_m.vocab_size, cfg_m.max_seq_len, bsz*n_steps)
    data_loader = DataLoader(
        dataset, batch_size=bsz,
        num_workers=8, pin_memory=True, shuffle=False,
        sampler=DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    )

    return data_loader

@dataclass
class GPTConfig:
    n_layers: int    # L
    n_heads: int     # H
    d_embd: int      # E
    max_seq_len: int = 1024
    vocab_size: int  = 50304 # V
    arch_name: str = 'gpt'

    def estimate_flops_per_token(self, model, bsz, rank=0):
        head_dim = self.d_embd // self.n_heads
        N = sum(p.numel() for p in model.parameters())  # get param count

        if rank == 0:
            print(f"Number of parameters: {N/1e9:.2f}B")    # print number of billion parameters 

        self.flops_per_token = 6 * N + 12 * self.n_layers * self.n_heads * head_dim * self.max_seq_len

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'

class Fp8GPTBlock(te.TransformerLayer):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__(
            d_embd,
            4*d_embd,
            n_heads,
            hidden_dropout=0.0,
            attention_dropout=0.0,
            layer_type='encoder',
            self_attn_mask_type='causal',
            normalization='LayerNorm',
            bias=True,
            activation='gelu',
            attn_input_format='bshd',
            fuse_qkv_params=True
        )

class Fp8GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(Fp8GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = te.LayerNorm(d_embd)

    def forward(self, idx_BT, is_first_microbatch):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE, is_first_microbatch=is_first_microbatch)

        # Couldn't fuse layer norm with linear due to weight tying
        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T

        return logits_BTV

def configure_train_loop(data_loader, cfg_m, bsz, rank=0):
    if rank != 0:
        for step_idx, data_batch in enumerate(data_loader):
            yield step_idx, data_batch
        return

    flops_per_iter = cfg_m.flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2610e12

    with tqdm(total=len(data_loader)) as pbar:
        for step_idx, data_batch in enumerate(data_loader):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            yield step_idx, data_batch

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            pbar.set_description(f'[rank0]  {(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')
            pbar.update()

def train(bsz = 28):

    torch.manual_seed(3985)
    world_size = torch.cuda.device_count()
    train_args = (
        world_size, bsz)

    try:
        mp.spawn(train_ddp, train_args, nprocs=world_size)
    except:
        destroy_process_group()

def train_ddp(
    rank, world_size, bsz
):
    # Construct process group
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '30985'})
    torch.cuda.set_device(rank)
    init_process_group(backend='nccl', rank=rank, world_size=world_size)

    cfg = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
        "arch_name": "gpt"
    }

    use_fp8 = True
    grad_acc_steps = 8
    n_steps = 128*8

    # Configure training setup
    model_cls = Fp8GPT
    cfg_m = GPTConfig(**cfg)
    model = Fp8GPT(**asdict(cfg_m)).to(rank)
    dprint(rank, f'Loaded {model_cls} model.', end=' ')
    cfg_m.estimate_flops_per_token(model, bsz, rank)

    data_loader = create_distributed_data_loader(rank, world_size, bsz, n_steps, cfg_m)
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)

    # DDP
    all_gpus = dist.new_group(backend='nccl')
    model = DDP(model, process_group=all_gpus, gradient_as_bucket_view=True)
    dprint(rank, f'Created DDP model')

    fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
    all_gpus = dist.new_group(backend='nccl')

    # Training loop
    loop_iter = configure_train_loop(data_loader, cfg_m, bsz, rank)
    model.train()

    for step_idx, data_batch in loop_iter:
        input_BT, label_BT = map(lambda t: t.pin_memory().to(rank), data_batch)

        # Only sync grads on the micro-step that precedes optimizer.step()
        model.require_backward_grad_sync = ((step_idx + 1) % grad_acc_steps == 0)

        with torch.amp.autocast('cuda', torch.bfloat16):
            with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
                weight_cache = use_fp8 and (step_idx % grad_acc_steps == 0)
                logits_BTV = model(input_BT, is_first_microbatch=weight_cache)
                loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
                loss /= grad_acc_steps
        loss.backward()

        if (step_idx + 1) % grad_acc_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

    destroy_process_group()

if __name__ == '__main__':
    import fire
    fire.Fire(train)
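To make the TFLOP/s numbers in this thread easier to sanity-check, here is a rough worked version of the FLOPs-per-token estimate for the config in the script. The flops_per_token formula is the one from GPTConfig.estimate_flops_per_token. The parameter count is an assumption: gpt2_param_count is a hypothetical helper that assumes a standard GPT-2 layout with biases and a tied output projection, so the count the script prints from the real TE modules may differ slightly.

```python
def gpt2_param_count(n_layers, d_embd, vocab_size, max_seq_len):
    """Analytic count, ASSUMING a standard GPT-2 layout with biases and tied output weights."""
    E = d_embd
    embeddings = vocab_size * E + max_seq_len * E
    # Per block: two LayerNorms (4E), QKV (3E^2 + 3E), attention projection (E^2 + E),
    # MLP up-projection (4E^2 + 4E) and MLP down-projection (4E^2 + E).
    per_block = 12 * E * E + 13 * E
    final_norm = 2 * E
    return embeddings + n_layers * per_block + final_norm


def flops_per_token(n_params, n_layers, n_heads, d_embd, max_seq_len):
    """Same formula as GPTConfig.estimate_flops_per_token in the reprod script."""
    head_dim = d_embd // n_heads
    return 6 * n_params + 12 * n_layers * n_heads * head_dim * max_seq_len


cfg = dict(n_layers=48, n_heads=25, d_embd=1600, max_seq_len=1024, vocab_size=50304)
n_params = gpt2_param_count(
    cfg["n_layers"], cfg["d_embd"], cfg["vocab_size"], cfg["max_seq_len"]
)
fpt = flops_per_token(
    n_params, cfg["n_layers"], cfg["n_heads"], cfg["d_embd"], cfg["max_seq_len"]
)
print(f"params ~ {n_params / 1e9:.2f}B, FLOPs/token ~ {fpt / 1e9:.2f} GFLOP")
```

Multiplying FLOPs/token by bsz * max_seq_len gives the per-iteration FLOPs that the progress bar divides by step time.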

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

wenchenvincent commented 1 month ago

@OrenLeung This issue occurs because our dev branch does not yet have all the recent DDP and FSDP optimizations from NVTE. We have a PR in review, to be merged soon, that could resolve this issue (https://github.com/ROCm/TransformerEngine/pull/66). Here are the numbers that I got with this PR:

8xMI300X DDP FP8 TE (batch size 28): 315 TFLOPs
8xMI300X DDP FP8 TE (batch size 32): 318 TFLOPs
8xMI300X DDP FP8 TE (batch size 42): 324 TFLOPs

OrenLeung commented 1 month ago

Hi @wenchenvincent,

Thanks for looking into this. These results look much better. Can you provide the Dockerfile instructions so I can reproduce them?

To be competitive with H100 on a perf per TCO basis, MI300X needs to hit 398 TFLOP/s/GPU. Do you have any other PRs or optimizations in the pipeline?

cc: @hliuca

Here are my preliminary Nvidia results for GPT2-1.5B FP8 full training:

Full response in the Llama3 70B proxy GitHub issue: https://github.com/ROCm/TransformerEngine/issues/78#issuecomment-2418538437
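For scale, the 398 TFLOP/s/GPU target can be compared against the best MI300X figure reported earlier in this thread (324 at batch size 42). This is only a back-of-the-envelope sketch: required_uplift is a hypothetical helper, and it assumes the 324 figure is per GPU, which the earlier comment does not state explicitly.

```python
TARGET_TFLOPS = 398.0       # perf-per-TCO target stated in the comment above
BEST_QUOTED_TFLOPS = 324.0  # batch size 42 result from the thread, ASSUMED to be per GPU


def required_uplift(current, target):
    """Fractional speedup needed to move from current to target throughput."""
    return target / current - 1.0


uplift = required_uplift(BEST_QUOTED_TFLOPS, TARGET_TFLOPS)
print(f"needed uplift: {uplift:.1%}")
```

Under that assumption the remaining gap is a bit over a fifth of current throughput.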

OrenLeung commented 1 month ago

After #66 was merged to main, I now get 322 TFLOP/s/GPU on this model in our internal codebase.

After 32 Warmup:
Mean TFLOP/s: 322.79
Mean MFU: 12.37%

This is similar to @wenchenvincent's TFLOP/s.
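The MFU figure reported above follows from the flops_promised value in the reprod script (2610e12). A minimal sketch of that arithmetic, assuming the internal codebase uses the same peak number; the mfu function here is illustrative and not taken from that codebase:

```python
PEAK_FLOPS = 2610e12  # flops_promised in the reprod script


def mfu(achieved_tflops, peak_flops=PEAK_FLOPS):
    """Model FLOPs utilization: achieved throughput divided by the promised peak."""
    return achieved_tflops * 1e12 / peak_flops


reported = mfu(322.79)
print(f"MFU={reported:.2%}")
```

The same function applied to the 398 TFLOP/s/GPU target mentioned earlier in the thread gives the utilization that target would correspond to.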