pytorch / ao

PyTorch native quantization and sparsity for training and inference

[float8] FP8 GPT 1.5B Delayed Scaling 2x Slower than BF16 #1297

Closed by OrenLeung 1 week ago

OrenLeung commented 2 weeks ago

Hi Torch Team,

I am currently experimenting with native torch float8 training and comparing it to Transformer Engine (TE) with the delayed scaling recipe, on GPT 1.5B at batch=12, seq=1024, on a 700W H100 SXM 80GB SKU.

I see that fp8 TE provides a slight perf improvement over autocast bf16, but unfortunately torchao.float8 is almost 2x slower. I attempted to improve performance by enabling fp8 and bf16 autocast at the same time, but ran into ValueError: All layers must have the same last seen input_dtype, got {torch.float32, torch.bfloat16}. Enabling fp8 together with bf16 autocast is something TE supports, but I am not sure whether it is needed for torchao.
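
Roughly, the combination that triggers the error looks like this (a minimal sketch based on the repro script below, not my exact code):

convert_to_float8_training(model, config=config)
model = torch.compile(model)
with torch.amp.autocast('cuda', torch.bfloat16):
    logits_BTV = model(input_BT)  # this combination raises the ValueError above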

Can you provide some guidance on how to improve performance on torchao.float8?

Thanks!

BF16 Autocast: 493.17 TFLOP/s
FP8 TE: 501.2 TFLOP/s
torchao.float8: 240.67 TFLOP/s 

Repro Script

import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)
import torch.nn.functional as F
import fire

class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)

    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE

class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT, **kwargs):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

# configure delayed scaling
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    # enable_amax_init=False,  # only needed for autocast + compile + FSDP +  float8 delayed
    # enable_pre_and_post_forward=False  # only needed for autocast + compile + FSDP +  float8 delayed
)
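
# For reference: the same config with dynamic scaling, which I understand is the
# torchao.float8 default. Not used for the numbers reported above.
# dynamic_config = Float8LinearConfig(
#     cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
#     cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
#     cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
# )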

def main(enable_fp8=True):
    torch.manual_seed(3985)
    torch.cuda.set_device(0)

    # GPT 1.5B
    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
        "arch_name": "gpt"
    }
    model = GPT(**cfg_json).to('cuda:0')

    N = sum(p.numel() for p in model.parameters())  # get param count

    flops_per_iter = 6 * N * 16 * 1024  # approx. FLOPs per step: 6 * params * tokens (16 * 1024 here)

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)

    if enable_fp8:
        convert_to_float8_training(model, config=config)

    model = torch.compile(model)

    for step_idx in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        input_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to('cuda:0')
        label_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to('cuda:0')

        start.record()
        if not enable_fp8:
            with torch.amp.autocast('cuda', torch.bfloat16):
                logits_BTV = model(input_BT)
                loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        else:
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())
        loss.backward()
        if enable_fp8:
            sync_float8_amax_and_scale_history(model)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        end.record()

        torch.cuda.synchronize()
        t = start.elapsed_time(end) / 1e3
        flops_per_sec = flops_per_iter / t
        print(f"finish {step_idx} step: {(flops_per_sec/1e12):.2f} TFLOP/s")

if __name__ == "__main__":
    fire.Fire(main)
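
For completeness, if I read the torchao API correctly, convert_to_float8_training also accepts a module_filter_fn callback, so restricting which layers get converted would look roughly like the sketch below (hypothetical keep_large_linears name, not used for the numbers above):

def keep_large_linears(mod, fqn):
    # hypothetical filter: only convert Linear layers whose dims are multiples of 16
    if isinstance(mod, nn.Linear):
        return mod.in_features % 16 == 0 and mod.out_features % 16 == 0
    return False

# convert_to_float8_training(model, config=config, module_filter_fn=keep_large_linears)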

Dependencies

$ pip list | grep torch
pytorch-triton               3.1.0+cf34004b8a
torch                        2.6.0.dev20241030+cu124
torch-tb-profiler            0.4.3
torchao                      0.7.0.dev20241112+cu121
vkuzo commented 1 week ago

Hi @OrenLeung, thanks for the repro! This looks like a bug in how we handle delayed scaling + autocast, let me take a look.

vkuzo commented 1 week ago

Thanks again for the report; https://github.com/pytorch/ao/pull/1306 should fix this. With that PR on my H100 machine:

OrenLeung commented 1 week ago

@vkuzo Thanks for the quick fix!

I am guessing you did your benchmark on the 500W H100 version?

I can confirm the fix using #1306! I am seeing the following:

vkuzo commented 1 week ago

I am guessing you did your benchmark on the 500W H100 version?

Yes, that's correct.

vkuzo commented 1 week ago

Closing since the fix landed.