pytorch / ao

PyTorch native quantization and sparsity for training and inference

[float8] DDP GPT1.5B Torch.compile dynamo error #1308

Open OrenLeung opened 1 week ago

OrenLeung commented 1 week ago

Hi Torch Team,

I am currently experimenting with native torch float8 distributed training using the delayed scaling recipe on GPT 1.5B with DDP at batch=12 seq=1024 on an HGX 8xH100 (700W H100 SXM 80G SKU).

Currently, I am running into a DDP + torch.compile + float8 bug. Without torch.compile enabled, I don't run into this error. I have tried #1306 as well as main@latest. Attached below is a self-contained repro and the error trace.

Commands

python3 test.py --enable_compile=False
python3 test.py --enable_compile=True

Error Trace

    submod_compiler.run(*example_inputs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 175, in run
    raise RuntimeError(*e.args) from e
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
RuntimeError: val

While executing %submod_479 : [num_users=1] = call_module[target=submod_479](args = (%l_self_modules_tsfmr_blks_modules_47_modules_ffn_modules_2_parameters_weight_, %l_self_modules_tsfmr_blks_modules_47_modules_ffn_modules_2_buffers_fp8_amax_weight_), kwargs = {})
Original traceback:
None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Repro Script

import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import fire
import torch.multiprocessing as mp
import os
from torch.distributed import init_process_group

class CausalSelfAttention(nn.Module):
    def __init__(self, d_embd, n_heads, **kwargs):
        super().__init__()
        self.d_head = d_embd // n_heads  # D
        self.attn_proj = nn.Linear(d_embd, 3*d_embd)
        self.out_proj = nn.Linear(d_embd, d_embd)

    def forward(self, x_BTE):
        qkv = self.attn_proj(x_BTE).split(x_BTE.size(-1), -1)
        split_attn_head = lambda z: z.unflatten(-1, [-1, self.d_head]).transpose(1, 2)
        q_BHTD, k_BHTD, v_BHTD = map(split_attn_head, qkv)
        o_BHTD = F.scaled_dot_product_attention(q_BHTD, k_BHTD, v_BHTD, dropout_p=0.0, is_causal=True)
        o_BTE = o_BHTD.transpose(1, 2).flatten(-2)
        y_BTE = self.out_proj(o_BTE)
        return y_BTE

class GPTBlock(nn.Module):
    def __init__(self, d_embd, **kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_embd)
        self.attn = CausalSelfAttention(d_embd, **kwargs)
        self.ffn_norm = nn.LayerNorm(d_embd)
        self.ffn = nn.Sequential(
            nn.Linear(d_embd, 4*d_embd),
            nn.GELU(),
            nn.Linear(4*d_embd, d_embd)
        )

    def forward(self, x_BTE):
        x_BTE = x_BTE + self.attn(self.attn_norm(x_BTE))
        y_BTE = x_BTE + self.ffn(self.ffn_norm(x_BTE))
        return y_BTE

class GPT(nn.Module):
    def __init__(self, vocab_size, max_seq_len, n_layers, d_embd, **kwargs):
        super().__init__()
        self.tok_embd = nn.Embedding(vocab_size, d_embd)
        self.pos_embd = nn.Embedding(max_seq_len, d_embd)
        self.tsfmr_blks = nn.ModuleList(GPTBlock(d_embd, **kwargs) for _ in range(n_layers))
        self.out_norm = nn.LayerNorm(d_embd)

    def forward(self, idx_BT, **kwargs):
        pos_T = torch.arange(idx_BT.size(1), dtype=torch.int64, device=idx_BT.device)
        x_BTE = self.tok_embd(idx_BT) + self.pos_embd(pos_T).unsqueeze(0)

        for tsfmr_blk in self.tsfmr_blks:
            x_BTE = tsfmr_blk(x_BTE)

        x_BTE = self.out_norm(x_BTE)
        logits_BTV = x_BTE @ self.tok_embd.weight.T  # Weight tying

        return logits_BTV

def main(enable_compile=True):

    train_args = (enable_compile,)

    mp.spawn(train, train_args, nprocs=8)

def train(rank, enable_compile=True):
    world_size = 8
    # configure delayed scaling
    config = Float8LinearConfig(
        cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
        cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
        cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
        # enable_amax_init=False,  # only needed for autocast + compile + FSDP +  float8 delayed
        # enable_pre_and_post_forward=False  # only needed for autocast + compile + FSDP +  float8 delayed
    )

    torch.manual_seed(3985)
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '30985'})
    torch.cuda.set_device(rank)
    init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # GPT 1.5B
    cfg_json = {
        "n_layers": 48,
        "n_heads": 25,
        "d_embd": 1600,
        "max_seq_len": 1024,
        "vocab_size": 50304,
        "arch_name": "gpt"
    }
    model = GPT(**cfg_json).to(rank)

    N = sum(p.numel() for p in model.parameters())  # get param count

    flops_per_iter = 6 * N * 12 * 1024  # ~6 FLOPs per param per token; batch=12, seq=1024

    optimizer = torch.optim.AdamW(model.parameters(), fused=True)

    convert_to_float8_training(model, config=config)

    model = DDP(model, gradient_as_bucket_view=True)
    if enable_compile:
        model = torch.compile(model)

    for step_idx in range(100):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        input_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to(rank)
        label_BT = torch.randint(50304, [12, 1024], dtype=torch.int64).to(rank)

        start.record()
        with torch.amp.autocast('cuda', torch.bfloat16):
            logits_BTV = model(input_BT)
            loss = F.cross_entropy(logits_BTV.flatten(0, 1), label_BT.flatten())

        loss.backward()

        sync_float8_amax_and_scale_history(model)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        end.record()

        torch.cuda.synchronize()
        t = start.elapsed_time(end) / 1e3
        flops_per_sec = flops_per_iter / t
        print(f"finish {step_idx} step: {(flops_per_sec/1e12):.2f} TFLOP/s")

if __name__ == "__main__":
    fire.Fire(main)

Torch Versions

pip list | grep torch
pytorch-triton               3.1.0+cf34004b8a
torch                        2.6.0.dev20241118+cu124
torch-tb-profiler            0.4.3
torchao                      0.7.0+git4402195e       $PATH/ao
vkuzo commented 6 days ago

Hi @OrenLeung, I can repro this as well. We haven't worked on enabling float8 + compile + DDP yet, as we've found that FSDP is significantly more common in jobs large enough to benefit from float8 training. Would you be open to FSDP with NO_SHARD instead of DDP? Context: https://discuss.pytorch.org/t/difference-between-ddp-vs-fsdp-no-shard/209729
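
For reference, a minimal sketch of what that swap could look like in the repro above (untested here, and assumes the rest of train() stays the same). ShardingStrategy.NO_SHARD keeps parameters fully replicated on every rank, so it behaves like DDP while going through the FSDP code path:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# ... after convert_to_float8_training(model, config=config) ...

# Replace the DDP wrap with FSDP in NO_SHARD mode; parameters remain fully
# replicated on every rank, so this is the FSDP analogue of DDP.
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.NO_SHARD,
    use_orig_params=True,  # generally needed for torch.compile compatibility
)
if enable_compile:
    model = torch.compile(model)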