ROCm / TransformerEngine


[FSDP 8xMI300X]: Llama3 70B 4 Layer Proxy Model GPU Core Dumps #78

Open OrenLeung opened 1 month ago

OrenLeung commented 1 month ago

Problem Description

On the Llama3 70B proxy model, training stalls and the GPUs produce core dumps. The GPU core dumps are 41 GB per GPU, so I am unable to send them. It is probably easier for y'all to reproduce this error on your end to get the GPU core dump.

I have verified that on H100, TE FP8 trains this Llama3 70B FSDP 4-layer model perfectly fine, with a 38% TFLOP/s/GPU increase compared to BF16 torch.compile.

cc: @hliuca


Operating System

Ubuntu

CPU

AMD CPU

GPU

MI300X

ROCm Version

ROCm 6.2.0

ROCm Component

No response

Steps to Reproduce

Docker Image

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt-get update && apt-get install -y nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]

TE install instructions (run inside the Docker container)

cd /workspace
git clone --recursive https://github.com/ROCm/TransformerEngine.git
export NVTE_USE_HIPBLASLT=1
export NVTE_FRAMEWORK=pytorch
export PYTORCH_ROCM_ARCH=gfx942
cd TransformerEngine && pip install .
cd /workspace/llm-train-bench

Reprod Script

from dataclasses import asdict
from typing import Optional
from pydantic.dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

# DDP
import os
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group

# FSDP
from functools import partial
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from tqdm import tqdm

# FP8 Transformer Engine
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

def dprint(rank, *args, **kwargs):
    if rank == 0:
        print(*args, **kwargs)

class DummyDataset(Dataset):
    def __init__(self, vocab_size, max_seq_len, ds_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.ds_len = ds_len

    def __getitem__(self, idx):
        input_T = torch.randint(self.vocab_size, [self.max_seq_len], dtype=torch.int64)
        # Next-token labels: inputs shifted left by one, plus one random final token
        label_T = torch.cat([input_T[1:], torch.randint(self.vocab_size, [1])])
        return input_T, label_T

    def __len__(self):
        return self.ds_len

def create_distributed_data_loader(rank, world_size, bsz, n_steps, cfg_m):
    dataset = DummyDataset(cfg_m.vocab_size, cfg_m.max_seq_len, bsz*n_steps)
    data_loader = DataLoader(
        dataset, batch_size=bsz,
        num_workers=8, pin_memory=True, shuffle=False,
        sampler=DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    )

    return data_loader

def configure_train_loop(data_loader, cfg_m, bsz, rank=0):
    if rank != 0:
        for step_idx, data_batch in enumerate(data_loader):
            yield step_idx, data_batch
        return

    flops_per_iter = cfg_m.flops_per_token * (bsz * cfg_m.max_seq_len)

    flops_promised = 2610e12

    with tqdm(total=len(data_loader)) as pbar:
        for step_idx, data_batch in enumerate(data_loader):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

            yield step_idx, data_batch

            end.record()
            torch.cuda.synchronize()

            t = start.elapsed_time(end) / 1e3
            flops_per_sec = flops_per_iter / t
            mfu = flops_per_sec / flops_promised

            pbar.set_description(f'[rank0]  {(flops_per_sec/1e12):.2f} TFLOP/s  MFU={mfu:.2%}')
            pbar.update()

@dataclass
class LLaMAConfig:
    n_layers: int    # L
    n_heads: int     # H
    n_kv_heads: int  # J
    d_embd: int      # E
    max_seq_len: int # T
    vocab_size: int  # V
    ffn_mult: float
    ffn_factor: int
    rope_base: float
    norm_eps: float
    d_hid: Optional[int] = None  # K
    arch_name: str = 'llama'

    def estimate_flops_per_token(self, model, bsz, rank=0):
        head_dim = self.d_embd // self.n_heads
        N = sum(p.numel() for p in model.parameters())  # get param count

        if rank == 0:
            print(f"Number of parameters: {N/1e9:.2f}B")    # print number of billion parameters 

        self.flops_per_token = 6 * N + 12 * self.n_layers * self.n_heads * head_dim * self.max_seq_len

    def __post_init__(self):
        assert self.d_embd % self.n_heads == 0, 'd_embd must be a multiple of n_heads.'
        assert self.d_embd % self.n_kv_heads == 0, 'd_embd must be a multiple of n_kv_heads.'
        assert self.n_kv_heads <= self.n_heads, 'n_kv_heads must not be larger than n_heads.'

        # FFN hidden dimension
        d_hid = int((4 * self.d_embd) * 2 / 3)
        d_hid = int(d_hid * self.ffn_mult)
        self.d_hid = self.ffn_factor * ((d_hid + self.ffn_factor - 1) // self.ffn_factor)                

class Fp8LLaMA(nn.Module):
    def __init__(self, vocab_size, d_embd, n_layers, n_heads, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.tsfmr_blks = nn.ModuleList(
            Fp8LLaMABlock(d_embd, n_heads=n_heads, **kwargs) for _ in range(n_layers)
        )
        self.norm_lm_head = te.LayerNormLinear(
            d_embd, vocab_size, bias=False,
            normalization='RMSNorm', eps=kwargs['norm_eps']
        )

        # Reference: https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
        freq_cis_TE = te.attention.RotaryPositionEmbedding(d_embd//n_heads)(max_seq_len=131072)
        self.register_buffer('freq_cis_TE', freq_cis_TE.to(torch.bfloat16))

    def forward(self, idx_BT, is_first_microbatch):
        x_BTE = self.tok_embd(idx_BT)
        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE, rotary_pos_emb=self.freq_cis_TE, is_first_microbatch=is_first_microbatch)
        logits_BTV = self.norm_lm_head(x_BTE)
        return logits_BTV

class Fp8LLaMABlock(te.TransformerLayer):
    ''' Reference Implementation:
    https://github.com/NVIDIA/TransformerEngine/blob/55dcbb4b02f560d52dc1215a9de348b37487ee3d/docs/examples/te_llama/te_llama.py#L42
    '''
    def __init__(self, d_embd, d_hid, n_heads, n_kv_heads, norm_eps, **kwargs):
        super().__init__(
            hidden_size=d_embd,
            num_attention_heads=n_heads,
            num_gqa_groups=n_kv_heads,  # TE expects the number of KV heads (GQA groups) here
            fuse_qkv_params=True,
            attn_input_format='bshd',
            attention_dropout=0.0,
            normalization='RMSNorm',
            layernorm_epsilon=norm_eps,
            ffn_hidden_size=d_hid,
            bias=False,
            activation='swiglu',
            hidden_dropout=0.0
        )

def train(
    bsz: int = 10,
):

    torch.manual_seed(3985)
    world_size = torch.cuda.device_count()
    train_args = (
        world_size,
        bsz
    )
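    # Each worker tears down its own process group at the end of train_fsdp;
    # the parent never initializes one, so it must not call destroy_process_group()
    mp.spawn(train_fsdp, train_args, nprocs=world_size)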

def train_fsdp(
    rank, world_size, bsz
):
    # Construct process group
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '30985'})
    torch.cuda.set_device(rank)
    init_process_group(backend='nccl', rank=rank, world_size=world_size)

    cfg = {
        "n_layers": 4,
        "n_heads": 64,
        "n_kv_heads": 8,
        "d_embd": 8192,
        "max_seq_len": 4096,
        "vocab_size": 128256,
        "ffn_mult": 1.3,
        "ffn_factor": 1024,
        "rope_base": 500000.0,
        "norm_eps": 1e-05,
        "d_hid": 28672,
        "arch_name": "llama"
    }

    use_fp8 = True
    grad_acc_steps = 8
    n_steps = 128*8
    # Configure training setup
    cfg_m, model_cls, blk_cls = LLaMAConfig(**cfg), Fp8LLaMA, Fp8LLaMABlock
    model = model_cls(**asdict(cfg_m)).to(rank)
    dprint(rank, f'Loaded {model_cls} model.', end=' ')
    cfg_m.estimate_flops_per_token(model, bsz, rank)  # Need to do before wrapping in FSDP

    data_loader = create_distributed_data_loader(rank, world_size, bsz, n_steps, cfg_m)

    # FSDP
    model = FSDP(
        model,
        device_id=rank,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16
        ),
        auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls={blk_cls}),
        use_orig_params=True
    )

    # Create the optimizer after FSDP wrapping, as the FSDP docs recommend
    optimizer = torch.optim.AdamW(model.parameters(), fused=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda t: 1.0)
    dprint(rank, f'Created FSDP model')

    prepare_te_modules_for_fsdp(model)
    dprint(rank, 'Sharded TE modules for FSDP')

    # Training loop
    loop_iter = configure_train_loop(data_loader, cfg_m, bsz, rank)
    model.train()

    fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
    all_gpus = dist.new_group(backend='nccl')

    for step_idx, data_batch in loop_iter:
        input_BT, label_BT = map(lambda t: t.pin_memory().to(rank), data_batch)

        with torch.amp.autocast('cuda', torch.bfloat16):
            with te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
                weight_cache = use_fp8 and (step_idx % grad_acc_steps == 0)
                logits_BTV = model(input_BT, is_first_microbatch=weight_cache)
                loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
                loss /= grad_acc_steps

        loss.backward()

        if (step_idx + 1) % grad_acc_steps == 0:
            model.clip_grad_norm_(1.0)  # FSDP's method clips using the global norm across all shards
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

    dist.barrier()
    destroy_process_group()

if __name__ == '__main__':
    import fire
    fire.Fire(train)
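As a sanity check on the throughput numbers in this thread, the formula used by `estimate_flops_per_token` in the script above (6N + 12·L·H·head_dim·T) can be evaluated analytically for the 4-layer proxy. This is a sketch under stated assumptions, not TE's own count: it assumes bias-free projections, SwiGLU gate/up/down weights, RMSNorm weights only, and an untied embedding and LM head. `ProxyConfig`, `estimate_params` and `flops_per_token` are hypothetical names; the script's printed "Number of parameters" remains the authoritative value.

```python
from dataclasses import dataclass


@dataclass
class ProxyConfig:
    """Hypothetical container for the 4-layer proxy config used in the script."""
    n_layers: int = 4
    n_heads: int = 64
    n_kv_heads: int = 8
    d_embd: int = 8192
    max_seq_len: int = 4096
    vocab_size: int = 128256
    d_hid: int = 28672


def estimate_params(c: ProxyConfig) -> int:
    """Analytic parameter count under the assumptions stated above (not TE's own count)."""
    head_dim = c.d_embd // c.n_heads
    kv_dim = c.n_kv_heads * head_dim                          # GQA: K and V are narrower than Q
    attn = 2 * c.d_embd * c.d_embd + 2 * c.d_embd * kv_dim   # Q and O, plus K and V
    mlp = 3 * c.d_embd * c.d_hid                              # SwiGLU: gate, up and down projections
    norms = 2 * c.d_embd                                      # pre-attention and pre-MLP RMSNorm weights
    embed_and_head = 2 * c.vocab_size * c.d_embd              # untied embedding and LM head
    final_norm = c.d_embd                                     # RMSNorm fused into norm_lm_head
    return c.n_layers * (attn + mlp + norms) + embed_and_head + final_norm


def flops_per_token(c: ProxyConfig, n_params: int) -> int:
    """Same formula as LLaMAConfig.estimate_flops_per_token: 6N + 12*L*H*head_dim*T."""
    head_dim = c.d_embd // c.n_heads
    return 6 * n_params + 12 * c.n_layers * c.n_heads * head_dim * c.max_seq_len


cfg = ProxyConfig()
n_params = estimate_params(cfg)
fpt = flops_per_token(cfg, n_params)
bsz = 10  # per-rank batch size, the script's default
flops_per_iter = fpt * bsz * cfg.max_seq_len  # per GPU, as in configure_train_loop
print(f"params ~ {n_params / 1e9:.2f}B, FLOPs/token ~ {fpt:.3e}, FLOPs/iter/GPU ~ {flops_per_iter:.3e}")
```

Because `bsz` in the script is the per-rank DataLoader batch size, `flops_per_iter`, and therefore the TFLOP/s shown in the tqdm bar, is a per-GPU figure.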

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

hliuca commented 1 month ago

I will report this. Thanks Oren.

OrenLeung commented 1 month ago

Thanks @hliuca ,

For further context, on MI300X with BF16 torch.compile (nightly), I get the following preliminary results:

The repro script uses batch size 10; I can confirm that batch size 12 also causes a GPU core dump.

OrenLeung commented 1 month ago

Interestingly, with batch size 2 I do not get a GPU core dump, but at such a small batch size the throughput is 491.22 TFLOP/s/GPU, which is 6% slower than BF16 at batch size 12 (preliminary).

batch size 2 command

python ./train_fsdp_llama_70_reprod.py --bsz=2

batch size 4 command

python ./train_fsdp_llama_70_reprod.py --bsz=4

batch size 10 (the default in the repro script above, so no arguments are needed)

python ./train_fsdp_llama_70_reprod.py

wenchenvincent commented 1 month ago

@OrenLeung This issue was caused by our dev branch not yet having all of the recent DDP and FSDP optimizations from NVTE. We have a PR in review that should be merged soon and could resolve this issue (https://github.com/ROCm/TransformerEngine/pull/66). Here are the numbers I got with this PR:

  • 8xMI300X FP8 TE batch size 2: 572 TFLOP/s/GPU
  • 8xMI300X FP8 TE batch size 10: 696 TFLOP/s/GPU

OrenLeung commented 1 month ago

hi @wenchenvincent ,

Thanks for looking into this! Do you have an ETA for when #66 will be merged? Since this is such a big PR, I will probably have to wait until it hits the main branch before I re-test. I will probably also wait for #69 and the transpose_cast_opt branch to be merged.

I was also wondering which Docker image you are using as the base image to obtain these results, and whether that base image is publicly accessible.

From your results, it does seem like your FP8 numbers are better than MI300X BF16. We estimate that the TCO of an MI300X is 78% of an H100's. So to get competitive perf-per-$ results vs H100, MI300X FP8 will probably need to hit 742.2 TFLOP/s/GPU.
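The 742.2 target appears to come from scaling H100's TE FP8 result in this thread (951.61 TFLOP/s/GPU at batch size 2) by the 78% TCO estimate. A minimal sketch of that arithmetic, assuming "competitive perf per $" means equal TFLOP/s per unit of TCO; `break_even_tflops` is a hypothetical helper:

```python
def break_even_tflops(competitor_tflops: float, tco_ratio: float) -> float:
    """TFLOP/s/GPU needed to match a competitor on throughput per unit of TCO.

    tco_ratio = (our TCO per GPU) / (competitor TCO per GPU).
    Equal perf per $ means ours / tco_ours == theirs / tco_theirs,
    so ours = theirs * tco_ratio.
    """
    return competitor_tflops * tco_ratio


h100_te_fp8 = 951.61     # 8xH100 TE FP8, batch size 2 (preliminary, from this thread)
mi300x_tco_ratio = 0.78  # estimated MI300X TCO relative to H100
target = break_even_tflops(h100_te_fp8, mi300x_tco_ratio)
print(f"MI300X FP8 break-even: {target:.2f} TFLOP/s/GPU")
```

This gives roughly 742.26 TFLOP/s/GPU, consistent with the 742.2 figure above.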

Are there other PRs, or ideas you have, that could help improve the performance of MI300X TE FP8?

cc: @hliuca

Here are my preliminary numbers on this GitHub issue's model (Llama3 70B 4-layer proxy):

  • 8xMI300X BF16 batch size 8: 508 TFLOP/s/GPU
  • 8xMI300X BF16 batch size 10: 512.64 TFLOP/s/GPU
  • 8xMI300X BF16 batch size 12: 518.19 TFLOP/s/GPU
  • 8xMI300X BF16 batch size 14: OOM
  • 8xH100 BF16 batch size 2: 649.02 TFLOP/s/GPU
  • 8xH100 BF16 batch size 4: 687.13 TFLOP/s/GPU
  • 8xH100 TE FP8 batch size 2: 951.61 TFLOP/s/GPU
  • 8xH100 TE FP8 batch size 4: 759.99 TFLOP/s/GPU

wenchenvincent commented 1 month ago

@OrenLeung #66 only needs a few minor changes; the bottleneck for merging it has been our CI capacity. I expect it to be merged this week.

I used the same Docker image that you used to produce your numbers.

I haven't had a chance to dump traces of this model run yet, but I suspect it also suffers from the FP8 cast-transpose issue, and some FP8 GEMMs might not be tuned yet. So the FP8 cast-transpose optimization and FP8 GEMM tuning could further improve performance.

OrenLeung commented 1 month ago

Furthermore, here are the preliminary H200 numbers. To be competitive with H200 on a perf-per-TCO basis, AMD needs to reach 910 TFLOP/s/GPU.

hliuca commented 1 month ago

Thank you Oren for providing H200 data. These data are very valuable and helpful. Our TE team and other teams are actively working on all the issues you have filed.

OrenLeung commented 1 month ago

hi @hliuca ,

I am glad we were able to provide an optimization goal.

Please note that all of the H100 and H200 numbers we shared are preliminary and will probably improve as I tune them.

Also please note that we are benchmarking and evaluating AMD and Nvidia on other real-world transformer models and real-world GEMM training shapes that we have not shared with Nvidia or AMD, to ensure that the patches made to PyTorch, TE, hipBLASLt, etc. are generalizable.

hliuca commented 1 month ago

Yes @OrenLeung, totally understood. Thank you for pushing us to do a better job.

OrenLeung commented 4 weeks ago

After https://github.com/ROCm/TransformerEngine/pull/66 was merged to main, I now get a preliminary number of 716.97 TFLOP/s/GPU on my internal codebase.

After 32 Warmup: Mean TFLOP/s: 716.97 Mean MFU: 27.47%
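The reported MFU is consistent with the repro script's definition, `mfu = flops_per_sec / flops_promised` with `flops_promised = 2610e12`, assuming the internal codebase uses the same peak value. A minimal check (`mfu` is a hypothetical helper):

```python
FLOPS_PROMISED = 2610e12  # flops_promised in the repro script's configure_train_loop


def mfu(achieved_tflops: float, peak_flops: float = FLOPS_PROMISED) -> float:
    """Model FLOPs utilization: achieved FLOP/s divided by the assumed peak FLOP/s."""
    return achieved_tflops * 1e12 / peak_flops


util = mfu(716.97)
print(f"MFU={util:.2%}")
```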

Great work! @wenchenvincent !

I assume that once the Triton fused cast-transpose op and CK FA v3 attention are merged, it will get closer to H100's FP8 result of 951.61 TFLOP/s/GPU.

wenchenvincent commented 4 weeks ago

@OrenLeung Thank you!

@wangye805 ran this model on a different machine and was getting 747 TFLOP/s. We're investigating why that system gives better performance and hope to make it reproducible.

Yeah, the Triton cast transpose should give further improvement. FP8 GEMM tuning in the hipBLASLt library and CK FA v3 should give more improvements as well, but for the latter two we will need to check the timeline internally.

OrenLeung commented 4 weeks ago

@wenchenvincent interesting that a different machine gives a different TFLOP/s.

Note that before step 16, the TFLOP/s reported by the repro script usually fluctuate (as it warms up and does gradient accumulation every 8 steps).

In my internal codebase, I usually do warmup of 32 steps then take the mean over 50 steps to get an accurate measurement of what the realistic TFLOP/s would be.
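The measurement approach described above (skip 32 warmup steps, then average the next 50) can be sketched as follows. `steady_state_tflops` is a hypothetical helper, not code from the internal codebase, and the sample values are illustrative only:

```python
from statistics import mean
from typing import Sequence


def steady_state_tflops(per_step_tflops: Sequence[float], warmup: int = 32, window: int = 50) -> float:
    """Mean TFLOP/s over `window` steps after skipping the first `warmup` steps."""
    if len(per_step_tflops) < warmup + window:
        raise ValueError(f"need at least {warmup + window} steps, got {len(per_step_tflops)}")
    return mean(per_step_tflops[warmup:warmup + window])


# Illustrative data only: slow warmup steps followed by a steady phase
samples = [300.0] * 32 + [700.0] * 50
print(f"Mean TFLOP/s after warmup: {steady_state_tflops(samples):.2f}")
```

Per-step values could come from the CUDA-event timings that `configure_train_loop` already computes. Since the script accumulates gradients over 8 steps, a window that is a multiple of `grad_acc_steps` (e.g. 48 or 56) would weight optimizer steps evenly; 50 is the value quoted in the comment.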

wenchenvincent commented 4 weeks ago

It could be that the other machine has a newer kernel driver version. There are also some system configuration tuning options that might impact performance: https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/system.html

OrenLeung commented 4 weeks ago

@wenchenvincent that is quite interesting, though from my understanding most of the knobs in the system tuning guide don't really affect text-only transformer models, since this class of models has very small DtoH and HtoD transfers and doesn't use the CPU much. So tuning NUMA (NPS1, NPS4, etc.) and similar settings doesn't really affect performance.

I can see how those knobs would affect models with heavy CPU data loading and heavy HtoD transfers, like image or video models.

wenchenvincent commented 3 weeks ago

@OrenLeung Those knobs are for general MI300X system tuning. The knob most relevant to the GPU would be this one: https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html#deterministic-clock For some workloads, running at the default frequency of 2100 MHz can trigger PCC (Peak Current Control) events, which lower the attainable GPU frequency.

wenchenvincent commented 3 weeks ago

Unfortunately, the machine that produced the better perf has been down in the past two days for maintenance and upgrade. Once it is up, we will continue to investigate why it could produce better numbers.

wenchenvincent commented 3 weeks ago

@OrenLeung Also, I think I might have forgotten to mention that we can use autotuning in TE to select the best-performing hipBLASLt kernel for a specific GEMM size (when several kernels are available for that size): https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#gemm-tuning-with-hipblaslt

wenchenvincent commented 3 weeks ago

The perf number that I got was without autotuning though. Once we get the machine back up, we will try with autotuning to see how much we can get.

OrenLeung commented 3 weeks ago

@wenchenvincent nice! I have also seen that there is an autotuning storage PR; what is the timeline for that? That way we wouldn't need to autotune on every run and could just cache the optimal GEMM selection.

wenchenvincent commented 3 weeks ago

@OrenLeung The PR is under review and we're looking to merge it end of this week or early next week.

wenchenvincent commented 1 week ago

@OrenLeung We have the optimized cast transpose Triton kernel merged in. And with that, I got the following improvement:

8xMI300X FP8 TE batch size 10: 701 TFLOP/s -> 751.88 TFLOP/s
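For a quick read of where this result lands, the relative improvement and the remaining gap to the perf-per-TCO targets quoted earlier in this thread (742.2 TFLOP/s/GPU vs H100 and 910 TFLOP/s/GPU vs H200) can be computed directly. A small sketch; `relative_gain` is a hypothetical helper and the targets are the preliminary figures from this thread:

```python
def relative_gain(before: float, after: float) -> float:
    """Fractional throughput improvement from `before` to `after`."""
    return after / before - 1.0


before, after = 701.0, 751.88  # 8xMI300X FP8 TE, batch size 10, TFLOP/s/GPU
gain = relative_gain(before, after)
print(f"Triton cast-transpose: +{gain:.1%}")

# Perf-per-TCO targets quoted earlier in this thread (preliminary)
targets = {"H100 parity": 742.2, "H200 parity": 910.0}
for name, target in targets.items():
    gap = target - after
    status = "met" if gap <= 0 else f"{gap:.2f} TFLOP/s/GPU short"
    print(f"{name} ({target} TFLOP/s/GPU): {status}")
```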

One of my colleagues got a better number, around 795 TFLOP/s, with a different machine and a different Docker image. I will check whether I can reproduce his numbers.

OrenLeung commented 1 week ago

hi @wenchenvincent thanks! can you send over the dockerfile?

hliuca commented 2 days ago

Hi @OrenLeung


Attached please find a Dockerfile. I am working with the dev teams to provide a final Dockerfile in the next few days. Meanwhile, if you like, you may try the following Dockerfile, which provides good performance. Thank you.

Dockerfile.rocm.txt