NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Does tp_overlap require the tensor-parallel size to equal the world size? #966

Open kuangdao opened 3 months ago

kuangdao commented 3 months ago

I want to set a TP size of 2 that is smaller than the global world size (8 GPUs in my setup).

The code is:


import os
import sys
import subprocess
import argparse

import torch
import torch.distributed as dist

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

def parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(
        description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
    )
    parser.add_argument(
        "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
    )
    parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
    parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
    parser.add_argument(
        "-n", "--num-heads", type=int, default=64, help="Number of attention heads."
    )
    parser.add_argument(
        "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
    )
    parser.add_argument(
        "--mlp-expansion-factor",
        type=int,
        default=4,
        help="MLP block intermediate size as a factor of hidden dimension.",
    )
    parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
    parser.add_argument(
        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
    )
    parser.add_argument(
        "--no-comm-overlap",
        action="store_true",
        default=False,
        help="Disable the comm+GEMM overlap.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", default=False)
    return parser.parse_args(argv, namespace)

def train(opts):
    WORLD_RANK = int(os.getenv("RANK"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE"))

    def dist_print(msg, end="\n", all_ranks=False):
        if WORLD_RANK == 0 or all_ranks:
            print(f"[RANK-{WORLD_RANK}] {msg}", end=end)

    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(opts.seed + WORLD_RANK)
    torch.cuda.manual_seed(opts.seed + WORLD_RANK)

    dist.init_process_group(
        backend="nccl",
        rank=WORLD_RANK,
        world_size=WORLD_SIZE,
        device_id=torch.device(f"cuda:{WORLD_RANK}"),
    )

    # Split the 8 ranks into four tensor-parallel groups of 2 ranks each.
    tp_group_0 = dist.new_group([0, 1], backend="nccl")
    tp_group_1 = dist.new_group([2, 3], backend="nccl")
    tp_group_2 = dist.new_group([4, 5], backend="nccl")
    tp_group_3 = dist.new_group([6, 7], backend="nccl")

    if WORLD_RANK in [0, 1]:
        tp_group = tp_group_0
    elif WORLD_RANK in [2, 3]:
        tp_group = tp_group_1
    elif WORLD_RANK in [4, 5]:
        tp_group = tp_group_2
    elif WORLD_RANK in [6, 7]:
        tp_group = tp_group_3

    tensor = torch.ones([2, 2]).cuda() * WORLD_RANK
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=tp_group)

    print("after allreduce is : {}".format(tensor))

    tp_size = dist.get_world_size(tp_group)

    ag_cfg = {  
        "method": "ring_exchange",
        "num_splits": 8,
        "num_sm": 1,
        "set_sm_margin": False,
    }
    rs_cfg = {  
        "method": "ring_exchange",
        "num_splits": 4,
        "num_sm": 1,
        "set_sm_margin": True,
    }
    hidden_size = opts.num_heads * opts.head_dim
    batched_size = opts.seq_length * opts.batch_size

    print("batched_size is : {}".format(batched_size))

    if not opts.no_comm_overlap:
        te.initialize_ub(
            [batched_size, hidden_size],
            tp_group,
            use_fp8=opts.fp8,
            dtype=torch.bfloat16,
            ub_cfgs={
                "fc1_fprop": ag_cfg,
                "fc1_dgrad": rs_cfg,
                "fc2_fprop": rs_cfg,
                "fc2_dgrad": ag_cfg,
            },
        )

    model = te.LayerNormMLP(
        hidden_size,
        opts.mlp_expansion_factor * hidden_size,
        params_dtype=torch.bfloat16,
        device="cuda",
        tp_group=tp_group,
        tp_size=tp_size,
        set_parallel_mode=True,
        sequence_parallel=True,  
        seq_length=opts.seq_length,
        micro_batch_size=opts.batch_size,
        ub_overlap_rs_dgrad=not opts.no_comm_overlap,
        ub_overlap_rs=not opts.no_comm_overlap,
        ub_overlap_ag=not opts.no_comm_overlap,
    )

    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    for i in range(opts.num_iters):
        dist_print(f"Iter {i+1}", all_ranks=opts.verbose)

        dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
        x = torch.rand(
            (opts.seq_length // tp_size, opts.batch_size, hidden_size),
            dtype=torch.bfloat16,
            device="cuda",
            requires_grad=True,
        )

        dist_print("|-- Forward pass", all_ranks=opts.verbose)
        with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
            y = model(x)
            dist_print("|-- Compute loss", all_ranks=opts.verbose)
            loss = y.flatten().sum()

        dist_print("|-- Backward pass", all_ranks=opts.verbose)
        loss.backward()

        dist_print("|-- Optimizer step", all_ranks=opts.verbose)
        optim.step()

    te.destroy_ub()
    dist.destroy_process_group()

if __name__ == "__main__":
    if "TORCHELASTIC_RUN_ID" in os.environ.keys():
        args = parse_args()
        train(args)
    else:
        subprocess.run(
            ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
            env=os.environ,
            check=True,
        )
    os._exit(0)

I run it with: torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_sub_group.py

The error is:

[screenshot of the error message]

The commit ID of TransformerEngine is 4a4f05dadf7032ff2f4c0780d9adcde77878c7b1.

The Docker image I am using is nvcr.io/nvidia/nemo:24.05.

timmoon10 commented 3 months ago

The tensor parallel group can be a subset of the world group. We frequently split the world group into orthogonal tensor-parallel, data-parallel, and pipeline-parallel groups.
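
For illustration, here is a minimal sketch (not TE code; make_tp_dp_groups is a hypothetical helper) of such an orthogonal split of an 8-rank world into tensor-parallel groups of 2 and data-parallel groups of 4, using plain torch.distributed:

import torch.distributed as dist

def make_tp_dp_groups(world_size: int, tp_size: int, rank: int):
    """Return (tp_group, dp_group) for this rank. Hypothetical helper, not a TE API."""
    dp_size = world_size // tp_size
    tp_group = dp_group = None
    # Tensor-parallel groups: consecutive ranks, e.g. [0, 1], [2, 3], ... for tp_size=2.
    for i in range(dp_size):
        ranks = list(range(i * tp_size, (i + 1) * tp_size))
        group = dist.new_group(ranks, backend="nccl")
        if rank in ranks:
            tp_group = group
    # Data-parallel groups: strided ranks, e.g. [0, 2, 4, 6] and [1, 3, 5, 7].
    for i in range(tp_size):
        ranks = list(range(i, world_size, tp_size))
        group = dist.new_group(ranks, backend="nccl")
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group

In your script the pair of loops would replace the four hard-coded dist.new_group calls; note that every rank must execute every new_group call in the same order.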

Based on the error message, it looks like there's an error when NCCL is initializing IPC communicators: https://github.com/NVIDIA/TransformerEngine/blob/4a4f05dadf7032ff2f4c0780d9adcde77878c7b1/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp#L501 To get more information, can you set NCCL_DEBUG=WARN in the environment?
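
If it helps, the variable can also be set from inside the script; a minimal sketch, assuming it runs before dist.init_process_group():

import os

# NCCL reads NCCL_DEBUG when the first communicator is created, so set it before
# dist.init_process_group() (or export it in the shell before launching torchrun).
os.environ["NCCL_DEBUG"] = "WARN"  # use "INFO" for even more verbose output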

kuangdao commented 3 months ago

I have set export NCCL_DEBUG=WARN, but there is no additional message:

[screenshot of the output]

denera commented 3 months ago

@kuangdao TE in general supports TP size < world size, but the comm+GEMM overlap has some unique restrictions. The underlying device-to-device comms code currently assumes TP size == world size. You may be able to get around this limitation by running with UB_SKIPMC=1, but this leverages CUDA IPC Handles instead of CUDA Multicast so it may not be as performant.
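
As a sketch of that workaround (an assumption about placement, not a guaranteed fix), the variable only needs to be visible in each rank's environment before te.initialize_ub() runs, e.g. near the top of train():

import os

# Suggested workaround: skip CUDA Multicast and fall back to CUDA IPC handles.
# Must be set on every rank before te.initialize_ub() is called.
os.environ["UB_SKIPMC"] = "1"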

As a disclaimer, comm+GEMM overlap is currently an experimental and somewhat fragile feature that is not yet fully supported in TE under all circumstances (and intentionally undocumented). That will change in the near future, as we improve the underlying device-to-device comms code and test it more rigorously on different platforms.

kuangdao commented 3 months ago

Thanks, I understand. I think the comm+GEMM overlap is an outstanding piece of work, and I hope more documentation, such as design and implementation notes, will be provided.

denera commented 1 month ago

@kuangdao -- we merged some changes to comm+GEMM overlap in the last month specifically to address multi-node mixed DP/TP use cases. The feature is still restricted to tp_size <= local_size, where local_size is the number of GPUs in a single NVLink domain (currently a single physical node of at most 8 GPUs), but it now functions correctly with model replication across node boundaries. Could you test again and confirm whether this works for your use case?
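
As a rough sketch of how a launcher script could respect that restriction (LOCAL_WORLD_SIZE is set by torchrun; the check itself is only an illustration of the rule above):

import os
import torch

# torchrun exports LOCAL_WORLD_SIZE = number of ranks launched on this node,
# which here stands in for the size of one NVLink domain.
local_size = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
tp_size = 2

# Comm+GEMM overlap currently requires the TP group to stay inside one NVLink
# domain; data-parallel replication covers the remaining ranks/nodes.
assert tp_size <= local_size, "comm+GEMM overlap needs tp_size <= GPUs per NVLink domain"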