Hi @liuhatry -- I tested PR #986 earlier today on 2 nodes of 8xH100s and confirmed that examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py works correctly for the following use cases:
- tp_size = local_size = world_size = 8 (1 node with 1 TP group per node and no replicas)
- tp_size = 4 and local_size = world_size = 8 (1 node with 2 TP groups per node, for a total of 2 model replicas)
- tp_size = local_size = 8 and world_size = 16 (2 nodes with 1 TP group per node, for a total of 2 model replicas)
- tp_size = 4, local_size = 8 and world_size = 16 (2 nodes with 2 TP groups per node, for a total of 4 model replicas)
These cases work both with and without UB_SKIPMC=1.
Can you test whether this PR resolves your issue?
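For reference, a minimal sketch (illustrative values only, not the example's actual code) of how the sizes in the list above relate to each other, assuming node-major rank assignment:

import os

world_size = 16   # total ranks across all nodes
local_size = 8    # ranks (GPUs) per node
tp_size = 4       # ranks per tensor-parallel group

num_nodes = world_size // local_size
num_replicas = world_size // tp_size  # one model replica per TP group
tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(num_replicas)]
print(f"{num_nodes} node(s), {num_replicas} replica(s), TP groups: {tp_groups}")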
Hi @denera -- I tested examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py: 1 node works, but 2 nodes fail.
Environment: H800, NVIDIA-SMI 535.161.08, Driver Version 535.161.08, CUDA Version 12.2, torch 2.1.1, TE 1.9.0.dev0+70111a3.
I modified the script because it cannot run on torch 2.1.1:
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
# Initialize torch.distributed global process group and get DP/TP groups
dist.init_process_group(
    backend="nccl",
    rank=WORLD_RANK,
    world_size=WORLD_SIZE,
    init_method=init_method,
    # device_id=torch.device(f"cuda:{LOCAL_RANK}"),  # not accepted by torch 2.1.1
)
node0: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr ${MASTER_IP} --master_port 60002 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --num-replicas 2
node1: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_addr ${MASTER_IP} --master_port 60002 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --num-replicas 2
LOG
File "examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py", line 175, in train add_ub(name, ub_cfg) File "/usr/local/python/lib/python3.8/site-packages/transformer_engine/pytorch/module/base.py", line 342, in initialize_ub te.module.base.initialize_ub(RuntimeError: File "/usr/local/python/lib/python3.8/site-packages/transformer_engine/pytorch/module/base.py", line 250, in add_ub add_ub(name, ub_cfg) File "/usr/local/python/lib/python3.8/site-packages/transformer_engine/pytorch/module/base.py", line 342, in initialize_ub /root/Transformer-Engine/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:584 in function register_user_buffer_collective: CUDA Error: invalid resource handle
Hi @liuhatry -- I updated PR #986 to prefer Gloo backend over NCCL whenever possible for bootstrapping Userbuffers. The application code still has to initialize NCCL process groups for TE modules, but this change eliminates the requirement for the device_id
argument for compatibility with older PyTorch versions.
I've also tested the example problem with init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
. This works for me with all the local/world size combinations above, but I would recommend setting the NCCL_SOCKET_IFNAME
to the correct network interface(s) just in case it fails to find the right one on its own.
Could you test if the latest changes resolve your issue?
Hi @denera, I tested the new code; it still fails:
NCCL_SOCKET_IFNAME=bond1 python3 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --num-replicas 1
!!! [NVTE] Bootstrapping Userbuffers with backend="gloo"
!!! [NVTE] Number of physical nodes: 1
!!! [NVTE] Global ranks on node 0: [0, 1, 2, 3, 4, 5, 6, 7]
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
MC initialized succesfully, window size = 549755813888
File "examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py", line 188, in train
File "/usr/local/python/lib/python3.8/site-packages/transformer_engine/pytorch/module/base.py", line 334, in initialize_ub
File "/usr/local/python/lib/python3.8/site-packages/transformer_engine/pytorch/module/base.py", line 258, in add_ub
File "/usr/local/python/lib/python3.8/site-packages/transformer_engine/pytorch/module/base.py", line 305, in allgather_callback
  torch.distributed.all_gather_into_tensor(global_tmp, local_tmp, group=ub_pgs[group])
File "/usr/local/python/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2897, in all_gather_into_tensor
  work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: no support for _allgather_base in Gloo process group
Hi @liuhatry -- if the Gloo backend in PyTorch distributed can't perform an all-gather of host (CPU) tensors across processes on a single node, that suggests something is broken outside of Transformer Engine.
Could you verify that you can perform the necessary collectives on host tensors with pure PyTorch (no TE code)?
For example:
import os
import torch
import torch.distributed as dist
# initialize default NCCL process group
world_rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
dist.init_process_group(backend="nccl", rank=world_rank, world_size=world_size)
# get a Gloo group for comms with host tensors
gloo_world = dist.new_group(backend="gloo")
localdata = torch.tensor([world_rank], dtype=torch.uint8, device="cpu")
globaldata = torch.empty(world_size, dtype=torch.uint8, device="cpu")
dist.all_gather_into_tensor(globaldata, localdata, gloo_world)
# verify result of all gather
reference = torch.tensor(list(range(world_size)), dtype=torch.uint8, device="cpu")
assert torch.equal(globaldata, reference)
The above is a simple representation of what happens when you run the comm+GEMM overlap example problem. The application initializes a default NCCL process group, and Transformer Engine then creates a Gloo process group for host tensor communication during Userbuffers bootstrapping.
If this does not run correctly, I would recommend working with your sysadmin to troubleshoot the machine you’re running on, and possibly reaching out to the PyTorch team as well for their feedback.
Hi @denera, your example code fails in the same way as before. I checked the torch code, and it indicates:
"The Gloo backend does not support this API."
https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L3392
File "gloo.py", line 14, in
Hi @liuhatry -- you're correct, Gloo supports all_gather()
but not all_gather_into_tensor()
. Can you confirm that the following snippet works?
import os
import socket
import torch
import torch.distributed as dist
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
MASTER_ADDR = str(os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())))
MASTER_PORT = str(os.getenv("MASTER_PORT", "1234"))
BOOTSTRAP_BACKEND = str(os.getenv("BOOTSTRAP_BACKEND", "gloo")).lower()
BOOTSTRAP_DEVICE = "cuda" if BOOTSTRAP_BACKEND == "nccl" else "cpu"
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(backend="nccl",
init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}",
rank=WORLD_RANK,
world_size=WORLD_SIZE)
bootstrap_world = dist.new_group(backend=BOOTSTRAP_BACKEND)
localdata = torch.tensor([WORLD_RANK], dtype=torch.uint8, device=BOOTSTRAP_DEVICE)
globaldata = torch.empty(WORLD_SIZE, dtype=torch.uint8, device=BOOTSTRAP_DEVICE)
dist.all_gather(list(globaldata.chunk(WORLD_SIZE)), localdata, bootstrap_world)
reference = torch.tensor(list(range(WORLD_SIZE)), dtype=torch.uint8, device=BOOTSTRAP_DEVICE)
assert torch.equal(globaldata, reference)
In order to use comm+GEMM overlap, your platform needs to be able to run this code snippet with BOOTSTRAP_BACKEND set to "gloo", "mpi", or "nccl".
If you can get this running, then the examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py
example should also work with te.module.base.initialize_ub(..., bootstrap_backend=<backend>)
set to the same BOOTSTRAP_BACKEND
that made the code snippet above work on your node. I also modified the example to make this easier to control with the --bootstrap-backend <backend>
argument.
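For reference, a hedged sketch of how the bootstrap backend could be wired into TE's Userbuffers setup from the application side. Only the bootstrap_backend keyword comes from this thread; the other argument names and values (buffer shape, TP size, dtype) are illustrative assumptions on my part, not the example's actual configuration:

import os
import torch
import transformer_engine.pytorch as te

BOOTSTRAP_BACKEND = os.getenv("BOOTSTRAP_BACKEND", "gloo")

te.module.base.initialize_ub(
    [4096, 12288],                        # assumed (seq * batch, hidden) buffer shape
    torch.distributed.get_world_size(),   # assumed TP size when all ranks form one TP group
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend=BOOTSTRAP_BACKEND,  # "gloo", "mpi", or "nccl"
)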
Hi @denera, I can run your snippet, but I cannot run ln_mlp_with_overlap.py with two nodes.
export BOOTSTRAP_BACKEND=nccl
torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 snippet.py

export GLOO_SOCKET_IFNAME=bond1
export BOOTSTRAP_BACKEND=gloo
torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 snippet.py

export BOOTSTRAP_BACKEND=nccl
node0: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 snippet.py
node1: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_addr $MASTER_ADDR --master_port 60000 snippet.py

export GLOO_SOCKET_IFNAME=bond1
export BOOTSTRAP_BACKEND=gloo
node0: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 snippet.py
node1: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_addr $MASTER_ADDR --master_port 60000 snippet.py
export BOOTSTRAP_BACKEND=nccl
torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60002 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend nccl --verbose

export GLOO_SOCKET_IFNAME=bond1
export BOOTSTRAP_BACKEND=gloo
torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60002 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend gloo --verbose
node0: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend nccl
node1: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_addr $MASTER_ADDR --master_port 60000 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend nccl
!!! [NVTE] Number of physical nodes: 1
!!! [NVTE] Global ranks on node 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
UDS: Sending data over socket /tmp/ub-ipc-socket-8-deadcafeb000 failed : Connection refused (111)
[E ProcessGroupNCCL.cpp:475] [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=12, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=1800000) ran for 1800839 milliseconds before timing out.
export GLOO_SOCKET_IFNAME=bond1
node0: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend gloo
node1: torchrun --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_addr $MASTER_ADDR --master_port 60000 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend gloo
!!! [NVTE] Bootstrapping Userbuffers with backend="gloo"
!!! [NVTE] Number of physical nodes: 1
!!! [NVTE] Global ranks on node 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
UDS: Sending data over socket /tmp/ub-ipc-socket-8-deadcafeb000 failed : Connection refused (111)
Traceback (most recent call last):
File "examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py", line 277, in
I checked the code (https://github.com/denera/TransformerEngine/blob/userbuffers-missing-data-parallel-pg/transformer_engine/pytorch/module/base.py#L128) and found that socket.gethostname() returns the same result on both nodes in my environment, so local_size becomes 16:
(Pdb) hostnames
['TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site']
The UDS (Unix Domain Socket) error you’re seeing is coming from the CUDA Multicast handle initialization.
Userbuffers bootstrapping needs to communicate CUDA Multicast handles between processes, but these handles are POSIX file descriptors that have to be communicated over Unix Domain Sockets in order for the kernel to reconstruct the descriptors correctly on every process. Trying to do this with comm libraries like MPI or NCCL mangles the descriptors and prevents processes from importing each others’ Multicast handles. The code under transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h/cc
is what handles these sends/recvs over the domain sockets.
It looks like these Unix Domain Sockets aren't working correctly on your nodes. Are there any limitations or permission issues on your node(s) that may be causing this? I will also try to provide a minimal C++ tester to help diagnose it without TE in the mix.
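Since that tester isn't attached here, a rough, TE-independent Python sketch of the same kind of check (passing a file descriptor between two processes over a Unix Domain Socket via SCM_RIGHTS, which is what the Userbuffers bootstrap relies on) might look like this; the socket path, payload, and structure are made up purely for illustration:

import array
import os
import socket
import tempfile
import time

SOCK_PATH = os.path.join(tempfile.mkdtemp(), "fd-pass-test")

def send_fd(sock, fd):
    # Ship one file descriptor in the ancillary data (SCM_RIGHTS).
    sock.sendmsg([b"x"], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd]))])

def recv_fd(sock):
    # Receive one file descriptor from the ancillary data.
    _, ancdata, _, _ = sock.recvmsg(1, socket.CMSG_LEN(array.array("i", [0]).itemsize))
    return array.array("i", ancdata[0][2])[0]

if os.fork() == 0:
    # Child: accept a connection, receive the descriptor, and read from it.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as srv:
        srv.bind(SOCK_PATH)
        srv.listen(1)
        conn, _ = srv.accept()
        fd = recv_fd(conn)
        print("child received fd, payload:", os.read(fd, 5))
    os._exit(0)
else:
    # Parent: open a pipe, write to it, and ship the read end to the child.
    time.sleep(0.5)  # crude wait for the child to bind the socket
    r, w = os.pipe()
    os.write(w, b"hello")
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as cli:
        cli.connect(SOCK_PATH)
        send_fd(cli, r)
    os.wait()

If this fails on your nodes, the CUDA Multicast handle exchange in Userbuffers would fail in a similar way.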
In the meantime, please disable Multicast with UB_SKIPMC=1
. If the snippet worked for you, I don’t see why this wouldn’t, as long as you’re initializing the default process group in the same way.
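For example, adapting the two-node command from above (all other arguments unchanged; this is just the same launch with the environment variable prepended):
node0: UB_SKIPMC=1 torchrun --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr $MASTER_ADDR --master_port 60000 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend gloo
node1: UB_SKIPMC=1 torchrun --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_addr $MASTER_ADDR --master_port 60000 examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py --num-iters=1000 --tcp-init --bootstrap-backend gloo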
Hi @denera, running with UB_SKIPMC=1 also fails:
!!! [NVTE] Bootstrapping Userbuffers with backend="gloo"
!!! [NVTE] Number of physical nodes: 1
!!! [NVTE] Global ranks on node 0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
MC NOT initialized and used
Traceback (most recent call last):
File "examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py", line 277, in
Revisiting an issue from earlier:
I checked the code (https://github.com/denera/TransformerEngine/blob/userbuffers-missing-data-parallel-pg/transformer_engine/pytorch/module/base.py#L128) and found that socket.gethostname() returns the same result on both nodes in my environment, so local_size becomes 16:
(Pdb) hostnames
['TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site', 'TENCENT64.site']
I'm guessing this is a consequence of a containerized cluster environment like Kubernetes, correct? The nodes are probably reachable by IP address but not by hostname.
Can you try replacing base.py lines 127-128 with the following?
hostname = socket.gethostname()
ifname = os.getenv(
    "NVTE_UB_SOCKET_IFNAME",
    os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)

if ifname is not None:
    # Resolve the interface to its IPv4 address (SIOCGIFADDR ioctl) and use it
    # in place of the hostname; needs `import fcntl` and `import struct` at the top.
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        hostname = socket.inet_ntoa(
            fcntl.ioctl(
                s.fileno(),
                0x8915,  # SIOCGIFADDR
                struct.pack("256s", ifname[:15].encode("UTF-8")),
            )[20:24]
        )
    except OSError as err:
        raise OSError(f"Invalid network interface: {ifname}") from err

hostnames = [None for _ in range(world_size)]
This will attempt to construct a list of global ranks on each physical node via the IP address on the specified network interface. You already run with NCCL_SOCKET_IFNAME
and GLOO_SOCKET_IFNAME
in your environment, so it should be able to pluck out the IP address from the correct interface here and hopefully find the correct distribution of global ranks.
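As a quick standalone sanity check (illustrative only; substitute your own interface name for "bond1"), each node can confirm that this ioctl resolves the interface to a distinct address:

# Check that the SIOCGIFADDR ioctl used in the patch above resolves the chosen
# interface to this node's IP address; "bond1" is just an example name.
import fcntl
import socket
import struct

ifname = "bond1"
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip = socket.inet_ntoa(
    fcntl.ioctl(s.fileno(), 0x8915, struct.pack("256s", ifname.encode("UTF-8")))[20:24]
)
print(socket.gethostname(), "->", ip)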
Edit: I updated PR #986 with this change and it should automatically kick in whenever you run with NCCL_SOCKET_IFNAME
or GLOO_SOCKET_IFNAME
set to a network interface in the environment.
Hi @denera, thanks for your reply.
Because my torch version is 2.1, the example fails when num_replicas=2:
File "examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py", line 175, in train
AttributeError: module 'torch.distributed' has no attribute 'device_mesh'
Can you update the example to fix the problem?
Hi @liuhatry -- I recently merged PR #986 into TE/main after confirming that it resolves multi-node issues for us in NeMo and Mcore. These changes also update the example so that it no longer uses a device mesh to handle replicas on a single-node run, so it should support older PyTorch versions.
Could you please test TE/main and let me know if it resolves the issue for you?
Hi @denera, I hit a new error when running on two nodes: the intra-node barrier hangs:
TENCENT64:860293:861393 [0] bootstrap.cc:150 NCCL WARN Bootstrap Root : rank 5 of 8 ranks has already checked in
I modified the new_group() calls, both in the example and in TE, like this:
intra_node_group = None
for i in range(num_nodes):
    ranks = list(range(i * local_size, (i + 1) * local_size))
    group = torch.distributed.new_group(backend="nccl", ranks=ranks)
    if world_rank in ranks:
        intra_node_group = group
And now I can run the example successfully. The problem is that new_group() requires all processes to enter the call, even if they are not going to be members of the group.
Hi @denera, can you please help confirm this issue? Thanks.
Hi @liuhatry -- I've reproduced the issue with TE/main but I'm able to resolve it by adding use_local_synchronization=True
to the group creation. This should eliminate the requirement for all ranks to enter the new_group()
call. Could you test if that works on your end?
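In other words, something along these lines -- a sketch only, assuming node-major rank ordering, the world_rank/local_size variables from your loop above, and a PyTorch version whose new_group() accepts use_local_synchronization:

# With use_local_synchronization=True, only the member ranks need to call
# new_group(), so each rank can create just its own node's group.
import torch.distributed as dist

my_node = world_rank // local_size
ranks = list(range(my_node * local_size, (my_node + 1) * local_size))
intra_node_group = dist.new_group(
    backend="nccl", ranks=ranks, use_local_synchronization=True
)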
I would also strongly recommend updating PyTorch and NCCL to the latest available versions and initializing the default NCCL process group in PyTorch with device_id=torch.device(f"cuda:{LOCAL_RANK}")
. Although it is not mandatory, binding each rank to a single device like this will allow PyTorch to create non-blocking sub-communicators with ncclCommSplit()
instead of ncclCommInitRank()
, and avoid deadlocks during NCCL bootstrapping.
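For reference, a minimal sketch of that initialization; the device_id argument is only accepted by newer PyTorch releases than the 2.1.1 you are running, which is why it had to be commented out earlier in this thread:

# Bind each rank to a single device at init time so PyTorch can create
# sub-groups with ncclCommSplit() instead of ncclCommInitRank().
import os
import torch
import torch.distributed as dist

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(
    backend="nccl",
    rank=int(os.getenv("RANK", "0")),
    world_size=int(os.getenv("WORLD_SIZE", "1")),
    device_id=torch.device(f"cuda:{LOCAL_RANK}"),
)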
Hi @liuhatry -- I filed a PR with a fix for this issue. Could you confirm if it works for you? Thanks!
Hi @denera, PR #1087 fixes my problem, thanks.
@liuhatry -- thanks for confirming. I merged the PR so TE/main should now have all the fixes we've discussed here. Please feel free to close the issue here if everything is resolved on your end. Thanks!
Machine
NVIDIA-SMI 535.161.08 Driver Version: 535.161.08 CUDA Version: 12.2
Software
torch 2.1.1 transformer-engine 1.9.0.dev0+56e0b35
Run Cmd:
deepspeed --hostfile hostfile --master_addr ${MASTER_IP} pretrain_gpt.py --deepspeed-activation-checkpointing --deepspeed_config=ds_config_gpt_test.json --deepspeed --tensor-model-parallel-size 4 --pipeline-model-parallel-size 1 ......
LOG
tp_size=8, world_size=8
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
NCCL_TOPO_AFFINITY set by environment to 0
MC initialized succesfully, window size = 549755813888
!!! [UBP2P] Register UBuf 1
!!! [UBP2P] Register UBuf 2
!!! [UBP2P] Register UBuf 3
!!! [UBP2P] Register UBuf 4
!!! [UB] Register UBuf 5
!!! [UB] Register UBuf 6
!!! [UB] Register UBuf 7
!!! [UB] Register UBuf 8
!!! [UB] Register UBuf 9
!!! [UB] Register UBuf 10
rank 7 | iteration 1/ 45776 | consumed samples: 128 | consumed tokens: 262144 | elapsed time this iteration (ms): 33222.7 |
tp_size=4, world_size=8
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:223 ''
tp_size=4, world_size=8 UB_SKIPMC=1
!!! [UB] Create UbufP2PCommOverlap Communicator
UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
MC NOT initialized and used
NCCL_TOPO_AFFINITY set by environment to 0
NCCL_TOPO_AFFINITY set by environment to 0
UB: warning region 1 size 40 MB registered without MC access
!!! [UBP2P] Register UBuf 1
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 ''
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 ''
UB: warning region 2 size 40 MB registered without MC access
!!! [UBP2P] Register UBuf 2
Failed, NCCL error TransformerEngine_official/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp:513 ''