pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Model Output Is Garbled and the Program Fails to Terminate Properly on 8 (or 2 or More) A100 GPUs (Single GPU Is Correct) #64

Open qianghuangwhu opened 9 months ago

qianghuangwhu commented 9 months ago

Model output is garbled when using multiple A100 GPUs (8, or 2 or more), and the program fails to terminate properly.

Environment

[Screenshots: environment details, 2023-12-21]

yifuwang commented 9 months ago

Hi @qianghuangwhu, can you try with PyTorch nightly build?

yifuwang commented 9 months ago

I was able to reproduce your issue on 2.1.2. You need to install the latest nightly build in order to enable torch.compile and tensor-parallel together.
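
For reference, a typical way to install a recent PyTorch nightly build is the command below; this assumes a CUDA 12.1 environment, so adjust the index URL for your CUDA version:

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121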

qianghuangwhu commented 9 months ago

Hi @qianghuangwhu, can you try with PyTorch nightly build?

Thank you very much for your reply. I will try the PyTorch nightly version and then update this issue. @yifuwang

qianghuangwhu commented 9 months ago

Hi @qianghuangwhu, can you try with PyTorch nightly build?

Hi @yifuwang, I have checked the source code of PyTorch 2.1.2: the torch._inductor.config module does not support the torch._inductor.config.fx_graph_cache = True setting, so PyTorch 2.1.2 is not suitable. I have therefore tried two PyTorch nightly versions, listed in the table after the quick check sketched below.
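
A minimal way to verify this in a given environment (this snippet is my illustration, not part of the original report):

import torch
import torch._inductor.config as inductor_config

# Report whether this PyTorch build exposes the inductor fx_graph_cache flag
# that gpt-fast's generate.py sets before compiling.
print(f"PyTorch version: {torch.__version__}")
if hasattr(inductor_config, "fx_graph_cache"):
    print("fx_graph_cache is available in this build")
else:
    print("fx_graph_cache is not available; a newer (nightly) build is needed")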

[Table: GPU model and the two PyTorch nightly versions tested]

Both of these nightly versions produce correct (non-garbled) model output when using tensor parallelism with 2 GPUs:

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=2 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth

[Screenshot: model output with tensor parallelism on 2 GPUs, 2023-12-23]

Error

1. PyTorch nightly 2.2.0 runs successfully on more than 2 GPUs, but hangs at the sys.exit(main()) call of torchrun after the model output is done.

Running Command

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth

[Screenshot: 8-GPU run output (torchrun hangs after generation completes), 2023-12-23]

2. PyTorch nightly 2.3.0 runs successfully on 2 GPUs and torchrun exits properly, but it errors out when using more than 2 GPUs.

Running Command

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth
WARNING:__main__:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
Loading model ...
Using int8 weight-only quantization!
Applying tensor parallel to model ...
Time to load model: 1.83 seconds
[rank1]:[W CUDAGraph.cpp:145] Warning: Waiting for pending NCCL work to finish before starting graph capture. (function operator())
[2023-12-23 11:32:05,914] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -6) local_rank: 0 (pid: 2708262) of binary: /software/qqq/Anaconda3/envs/gpt-fast/bin/python
Traceback (most recent call last):
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/site-packages/torch/distributed/run.py", line 816, in <module>
    main()
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/software/qqq/Anaconda3/envs/gpt-fast/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
========================================================
generate.py FAILED
--------------------------------------------------------
Failures:
[1]:
  time      : 2023-12-23_11:32:05
  host      : user
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 2708263)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2708263
[2]:
  time      : 2023-12-23_11:32:05
  host      : user
  rank      : 2 (local_rank: 2)
  exitcode  : -6 (pid: 2708264)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2708264
[3]:
  time      : 2023-12-23_11:32:05
  host      : user
  rank      : 3 (local_rank: 3)
  exitcode  : -6 (pid: 2708265)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2708265
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-12-23_11:32:05
  host      : user
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 2708262)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2708262
========================================================

The reason the code hangs at the sys.exit(main()) call of torchrun after the model output is done, or errors out when using more than 2 GPUs, may be that I'm using 8 x A100 PCIe GPUs connected with P2P NVLink bridges, not SXM. In that case I need to set NCCL_P2P_LEVEL=NVL so that NCCL uses P2P when GPUs are connected through NVLink. So I used the command:

NCCL_P2P_LEVEL=NVL ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth

but it does not work.
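
One additional diagnostic (my own illustration, not something run in this thread) is to print the CUDA peer-to-peer access matrix between the local GPUs; pairs joined by an NVLink bridge should report peer access, while what the other pairs report depends on the PCIe topology:

import torch

# Print which local GPU pairs report CUDA peer-to-peer access.
num_gpus = torch.cuda.device_count()
for i in range(num_gpus):
    row = []
    for j in range(num_gpus):
        row.append("-" if i == j else str(torch.cuda.can_device_access_peer(i, j)))
    print(f"GPU {i}: {' '.join(row)}")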

yifuwang commented 9 months ago
  1. PyTorch nightly 2.2.0 runs successfully on more than 2 GPUs, but hangs at the sys.exit(main()) call of torchrun after the model output is done.

I suspect that this is an issue with ProcessGroupNCCL. Could you try running it without ENABLE_INTRA_NODE_COMM=1? If the problem persists, we can confirm that it's an issue related to ProcessGroupNCCL.
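
That is, assuming the same checkpoint path as in your earlier runs, something like:

torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth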

  2. PyTorch nightly 2.3.0 runs successfully on 2 GPUs and torchrun exits properly, but it errors out when using more than 2 GPUs.

Can you try running the minimal example below to verify your NCCL setup?

import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    pid = os.getpid()

    def log(msg) -> None:
        print(f"[rank {rank}, pid {pid}] {msg}")

    torch.cuda.set_device(f"cuda:{local_rank}")

    log("Initializing process group...")
    dist.init_process_group(backend="nccl")
    log("Process group initialization completed")

    log("Testing all_reduce...")
    t = torch.full((8, 8), rank, device="cuda")
    dist.all_reduce(t)
    assert t.eq(world_size * (world_size - 1) // 2).all()
    log("All_reduce completed")

Run it with torchrun --nproc_per_node=8 --monitor-interval=1 [name].py. You can also try the following for troubleshooting:

  • nvidia-smi topo -m

qianghuangwhu commented 9 months ago
  • nvidia-smi topo -m

Hi @yifuwang, thanks. I agree that this is a ProcessGroupNCCL issue caused by my cards' connection topology. I tried your suggestions, but the problem persists.

1. The connection topology of my cards, from nvidia-smi topo -m, is:

[Screenshot: nvidia-smi topo -m output, 2023-12-24]

As shown in the screenshot, each pair of cards is connected via an NVLink bridge. In my past experience, I need to set NCCL_P2P_LEVEL=NVL when using 8 cards for LLM inference; otherwise, the program reports an error. (See also the NCCL debug suggestion at the end of this comment.)

2. I can run your given code on 8 GPUs with both PyTorch nightly versions without any error.

code

import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    pid = os.getpid()

    def log(msg) -> None:
        print(f"[rank {rank}, pid {pid}] {msg}")

    torch.cuda.set_device(f"cuda:{local_rank}")

    log("Initializing process group...")
    dist.init_process_group(backend="nccl")
    log("Process group initialization completed")

    log("Testing all_reduce...")
    t = torch.full((8, 8), rank, device="cuda")
    dist.all_reduce(t)
    assert t.eq(world_size * (world_size - 1) // 2).all()
    log("All_reduce completed")

Running commands

The following three commands all work fine:

torchrun --nproc_per_node=8 --monitor-interval=1 nccl_test.py
NCCL_P2P_LEVEL=NVL torchrun --nproc_per_node=8 --monitor-interval=1 nccl_test.py
ENABLE_INTRA_NODE_COMM=1 NCCL_P2P_LEVEL=NVL torchrun --nproc_per_node=8 --monitor-interval=1 nccl_test.py

Output

[2023-12-24 08:32:56,418] torch.distributed.run: [WARNING] 
[2023-12-24 08:32:56,418] torch.distributed.run: [WARNING] *****************************************
[2023-12-24 08:32:56,418] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2023-12-24 08:32:56,418] torch.distributed.run: [WARNING] *****************************************
[rank 5, pid 2827216] Initializing process group...
[rank 5, pid 2827216] Process group initialization completed
[rank 5, pid 2827216] Testing all_reduce...
[rank 3, pid 2827214] Initializing process group...
[rank 4, pid 2827215] Initializing process group...
[rank 3, pid 2827214] Process group initialization completed
[rank 3, pid 2827214] Testing all_reduce...
[rank 6, pid 2827217] Initializing process group...
[rank 4, pid 2827215] Process group initialization completed
[rank 4, pid 2827215] Testing all_reduce...
[rank 7, pid 2827218] Initializing process group...
[rank 6, pid 2827217] Process group initialization completed
[rank 6, pid 2827217] Testing all_reduce...
[rank 7, pid 2827218] Process group initialization completed
[rank 7, pid 2827218] Testing all_reduce...
[rank 2, pid 2827213] Initializing process group...
[rank 0, pid 2827211] Initializing process group...
[rank 2, pid 2827213] Process group initialization completed
[rank 2, pid 2827213] Testing all_reduce...
[rank 0, pid 2827211] Process group initialization completed
[rank 0, pid 2827211] Testing all_reduce...
[rank 1, pid 2827212] Initializing process group...
[rank 1, pid 2827212] Process group initialization completed
[rank 1, pid 2827212] Testing all_reduce...
[rank 7, pid 2827218] All_reduce completed
[rank 5, pid 2827216] All_reduce completed
[rank 3, pid 2827214] All_reduce completed
[rank 1, pid 2827212] All_reduce completed
[rank 4, pid 2827215] All_reduce completed
[rank 2, pid 2827213] All_reduce completed
[rank 6, pid 2827217] All_reduce completed
[rank 0, pid 2827211] All_reduce completed

3. Another thing worth mentioning: if I don't use the --compile flag, all of the following commands work fine on 8 GPUs with PyTorch nightly, albeit at a low inference speed.

ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py  --checkpoint_path checkpoints/$MODEL_REPO/model.pth
NCCL_P2P_LEVEL=NVL ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py  --checkpoint_path checkpoints/$MODEL_REPO/model.pth
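
As an aside (my own suggestion, not something run in this thread): NCCL's debug logging can show which transport, such as NVLink/PCIe P2P or shared memory, gets selected for each GPU pair in the failing --compile runs, for example:

NCCL_DEBUG=INFO NCCL_P2P_LEVEL=NVL ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth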