Closed: wbmc closed this issue 11 months ago.
Since you've created the PR, how did you test it? Specifically, what command did you run?
Yes, I have tested it on 2 GPU hosts. It requires a few changes to test_gpu_multi_hosts.py:
import logging
import os
import time
import torch
import torch.nn as nn
import torch.distributed as dist
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt_backend
from torch_xla._internal.pjrt import *
from torch_xla._internal import gpu, pjrt
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel
from torch_xla.distributed.fsdp.wrap import (
    always_wrap_policy as always_wrap,)
from multiprocessing import Process


class MyLinear(nn.Linear):
  """Linear layer with deterministic reset_parameters for testing."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

  def reset_parameters(self, *args, **kwargs):
    with torch.no_grad():
      self.weight.fill_(1)


class MyModel(nn.Module):

  def __init__(self, device):
    super().__init__()
    self.lin1 = MyLinear(2, 2, bias=False, device=device)
    self.lin2 = MyLinear(2, 2, bias=False, device=device)

  def forward(self, x):
    return self.lin2(self.lin1(x))

  def reset_parameters(self, *args, **kwargs):
    for m in [self.lin1, self.lin2]:
      if not isinstance(m, XlaFullyShardedDataParallel):
        m.reset_parameters()


def forward():
  with torch.no_grad():
    device = xm.xla_device()
    model = MyModel(device)
    inp = torch.randn(10, 2, device=device)
    logits = model(inp)
  return logits


def _mp_fn(index, *args, **kwargs):
  dist.init_process_group('xla', init_method='xla://')
  logits = forward()
  output_tensors = [
      torch.zeros_like(logits, device=xm.xla_device())
      for _ in range(int(os.environ['PJRT_WORLD_SIZE']))
  ]
  # test all-gather
  dist.all_gather(output_tensors, logits)
  # test all-reduce
  dist.all_reduce(logits)
  xm.mark_step()
  return None


def worker_fn(local_rank,
              group_rank,
              local_world_size,
              world_size,
              cuda_visible_devices='',
              *args,
              **kwargs):
  # Export this worker's topology for the PJRT runtime.
  # Global rank = local_rank + group_rank * local_world_size.
  os.environ[xenv.PJRT_LOCAL_RANK] = str(local_rank)
  os.environ[xenv.PJRT_LOCAL_WORLD_SIZE] = str(local_world_size)
  os.environ[xenv.PJRT_GROUP_RANK] = str(group_rank)
  os.environ[xenv.PJRT_RANK] = str(local_rank + group_rank * local_world_size)
  os.environ[xenv.PJRT_WORLD_SIZE] = str(world_size)
  # Restrict this worker to the GPUs assigned to its group.
  os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices
  pjrt.initialize_multiprocess(local_rank, group_rank, local_world_size,
                               world_size)
  return _mp_fn(local_rank, *args, **kwargs)


def master_fn(world_size):
  # (Re)start the distributed runtime service and keep this process alive
  # until the main process kills it.
  gpu.shutdown_distributed_runtime()
  gpu.initialize_distributed_runtime(world_size)
  time.sleep(3600)
The first host runs the following command, with this main block in test_gpu_multi_hosts.py:
PJRT_DIST_SERVICE_ADDR=0.0.0.0:30285 MASTER_PORT=30286 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PJRT_DEVICE=GPU GPU_NUM_DEVICES=8 python test_gpu_multi_hosts.py
if __name__ == '__main__':
  world_size = 16
  # Start the distributed runtime service in a separate process.
  master_proc = Process(target=master_fn, args=(world_size,))
  master_proc.start()
  os.environ['GPU_NUM_DEVICES'] = str(world_size)
  os.environ['PJRT_WORLD_SIZE'] = str(world_size)
  # ((local_rank, group_rank, local_world_size, world_size), kwargs):
  # groups 0-3 run on this host, 2 processes per group, 2 GPUs per group.
  configs = (
      ((0, 0, 2, world_size), {'cuda_visible_devices': '0,1'}),
      ((1, 0, 2, world_size), {'cuda_visible_devices': '0,1'}),
      ((0, 1, 2, world_size), {'cuda_visible_devices': '2,3'}),
      ((1, 1, 2, world_size), {'cuda_visible_devices': '2,3'}),
      ((0, 2, 2, world_size), {'cuda_visible_devices': '4,5'}),
      ((1, 2, 2, world_size), {'cuda_visible_devices': '4,5'}),
      ((0, 3, 2, world_size), {'cuda_visible_devices': '6,7'}),
      ((1, 3, 2, world_size), {'cuda_visible_devices': '6,7'}),
  )
  procs = []
  for config in configs:
    procs.append(Process(target=worker_fn, args=config[0], kwargs=config[1]))
    procs[-1].start()
  for p in procs:
    p.join()
  master_proc.kill()
The second host also uses a world size of 16, but with group ranks 4 to 7 instead of 0 to 3, and runs the command:
PJRT_DIST_SERVICE_ADDR=x.x.x.x:30285 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PJRT_DEVICE=GPU GPU_NUM_DEVICES=8 MASTER_PORT=30286 python test_gpu_multi_hosts.py
if __name__ == '__main__':
  world_size = 16
  # Start the distributed runtime service in a separate process.
  master_proc = Process(target=master_fn, args=(world_size,))
  master_proc.start()
  os.environ['GPU_NUM_DEVICES'] = str(world_size)
  os.environ['PJRT_WORLD_SIZE'] = str(world_size)
  # ((local_rank, group_rank, local_world_size, world_size), kwargs):
  # groups 4-7 run on this host, 2 processes per group, 2 GPUs per group.
  configs = (
      ((0, 4, 2, world_size), {'cuda_visible_devices': '0,1'}),
      ((1, 4, 2, world_size), {'cuda_visible_devices': '0,1'}),
      ((0, 5, 2, world_size), {'cuda_visible_devices': '2,3'}),
      ((1, 5, 2, world_size), {'cuda_visible_devices': '2,3'}),
      ((0, 6, 2, world_size), {'cuda_visible_devices': '4,5'}),
      ((1, 6, 2, world_size), {'cuda_visible_devices': '4,5'}),
      ((0, 7, 2, world_size), {'cuda_visible_devices': '6,7'}),
      ((1, 7, 2, world_size), {'cuda_visible_devices': '6,7'}),
  )
  procs = []
  for config in configs:
    procs.append(Process(target=worker_fn, args=config[0], kwargs=config[1]))
    procs[-1].start()
  for p in procs:
    p.join()
  master_proc.kill()
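The eight hand-written config tuples on each host follow a regular pattern. As a sketch (make_configs and global_rank are hypothetical helpers, not part of the PR), they could be generated from the host index, with the global rank computed by the same formula worker_fn uses for PJRT_RANK:

```python
def make_configs(host_index, groups_per_host, local_world_size, world_size):
    """Build the (args, kwargs) tuples passed to worker_fn on one host."""
    configs = []
    for g in range(groups_per_host):
        # Hosts own consecutive blocks of group ranks
        # (with 4 groups per host: host 0 -> 0..3, host 1 -> 4..7).
        group_rank = host_index * groups_per_host + g
        # Each group gets its own contiguous slice of this host's GPUs.
        first_gpu = g * local_world_size
        devices = ','.join(
            str(d) for d in range(first_gpu, first_gpu + local_world_size))
        for local_rank in range(local_world_size):
            args = (local_rank, group_rank, local_world_size, world_size)
            configs.append((args, {'cuda_visible_devices': devices}))
    return tuple(configs)


def global_rank(local_rank, group_rank, local_world_size):
    # Same formula worker_fn uses to set PJRT_RANK.
    return local_rank + group_rank * local_world_size


host0 = make_configs(host_index=0, groups_per_host=4, local_world_size=2, world_size=16)
host1 = make_configs(host_index=1, groups_per_host=4, local_world_size=2, world_size=16)
all_ranks = sorted(global_rank(*args[:3]) for args, _ in host0 + host1)
print(all_ranks == list(range(16)))  # True
```

Generating the tuples this way makes it easy to check that the two hosts together cover global ranks 0 to 15 exactly once.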
What do you need cuda_visible_devices for?
In a distributed setup, there may be multiple hosts, groups (local environments), and processes. Each host or group contains several processes, and each process is bound to a specific GPU.
However, a host is not necessarily equivalent to a group. The distinction matters because users commonly need to partition a host's devices into distinct groups, much as in hybrid cloud deployments or resource pooling.
So how do we determine which GPUs are allocated to a specific group? This is where the CUDA_VISIBLE_DEVICES environment variable comes in: it specifies the devices that a given set of processes can access.
For example, suppose we set allowed_devices to {local_rank}. This means the device at position local_rank in the CUDA_VISIBLE_DEVICES list is the allowed device for that process. For instance, if CUDA_VISIBLE_DEVICES is "4,5,6" and local_rank is 1, GPU 5 is the allowed device for this process.
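The mapping from local_rank to a physical GPU described above can be sketched as follows. This is an illustration, not code from the PR: the helper names are hypothetical, and the sketch assumes CUDA_VISIBLE_DEVICES holds integer IDs (CUDA also accepts GPU UUIDs, which it does not handle). Inside the process, CUDA renumbers the visible devices starting from 0, so the ordinal the process itself uses is local_rank, while the physical GPU is the local_rank-th entry of the list.

```python
import os


def visible_devices(env=None):
    """Parse CUDA_VISIBLE_DEVICES into a list of physical GPU IDs.

    Returns None if the variable is unset, meaning all GPUs are visible.
    """
    env = os.environ if env is None else env
    value = env.get('CUDA_VISIBLE_DEVICES')
    if value is None:
        return None
    # Skip empty entries, e.g. from a trailing comma such as "4,5,6,".
    return [int(tok) for tok in value.split(',') if tok.strip()]


def allowed_device(local_rank, env=None):
    """Physical GPU that a process with this local_rank is allowed to use."""
    devices = visible_devices(env)
    if devices is None:
        # No restriction: local_rank indexes all GPUs on the host directly.
        return local_rank
    if not 0 <= local_rank < len(devices):
        raise ValueError(f'local_rank {local_rank} is out of range for '
                         f'{len(devices)} visible devices')
    # The process sees this GPU as ordinal local_rank; physically it is:
    return devices[local_rank]


print(allowed_device(1, {'CUDA_VISIBLE_DEVICES': '4,5,6'}))  # 5
```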
Here are descriptions of some key distributed runtime arguments:
- local_rank: The rank of the current process within its local group.
- local_world_size: The total number of processes within the local group.
- group_rank: The rank of the worker group, typically ranging from 0 to max_nnodes - 1.
- rank: The rank of the process in the global world (local_rank + group_rank * local_world_size in the test above).
- world_size: The total size of the global world.
- CUDA_VISIBLE_DEVICES: Specifies the visible devices for a given process.
CUDA_VISIBLE_DEVICES is used to specify the devices that can be accessed by a given set of processes.
It seems you set it as an env var via os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices. I wonder whether it's only used in the test or also in the service code, as I don't see it being used in the service code.
Moreover, at a high level, compared to what you have, the team wants a user experience similar to PyTorch's. Users would run a command:
$ torchrun
--nproc_per_node=4 # use 4 GPUs on this machine
--nnodes=2 # use 2 machines
--node_rank=0
--rdzv_id=456
--rdzv_backend=c10d
--rdzv_endpoint=..
multinode_torchrun.py 50 10
on each GPU machine. This is what we want to do at the PyTorch/XLA level. Let me add you to the design doc. Is wbmc@163.com a good email for you? If so, you should have received a notification for the design doc.
Thanks!
It appears that torchrun acts as a high-level interface. The approach in this PR is compatible with torchrun and requires only a few code changes. For instance, users can run the command with the arguments above, and PyTorch/XLA can then translate them into environment variables so that the PJRT runtime can read the arguments directly.
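A minimal sketch of that translation, assuming the worker is launched by torchrun, which sets LOCAL_RANK, RANK, GROUP_RANK, LOCAL_WORLD_SIZE and WORLD_SIZE for each worker. The target names mirror the xenv constants used in the test above, but their exact string values are an assumption, as is the helper name translate_torchrun_env:

```python
import os

# Hypothetical mapping: torchrun's per-worker variables -> the PJRT_* names
# used by the test in this PR (the PJRT_* string values are assumed).
TORCHRUN_TO_PJRT = {
    'LOCAL_RANK': 'PJRT_LOCAL_RANK',
    'LOCAL_WORLD_SIZE': 'PJRT_LOCAL_WORLD_SIZE',
    'GROUP_RANK': 'PJRT_GROUP_RANK',
    'RANK': 'PJRT_RANK',
    'WORLD_SIZE': 'PJRT_WORLD_SIZE',
}


def translate_torchrun_env(env=None):
    """Copy torchrun's variables into PJRT_* variables in `env` (in place).

    PJRT_* variables that are already set are left untouched, so an explicit
    user setting wins. Returns the resulting PJRT_* values.
    """
    env = os.environ if env is None else env
    missing = [k for k in TORCHRUN_TO_PJRT if k not in env]
    if missing:
        raise RuntimeError(f'not launched by torchrun? missing {missing}')
    result = {}
    for src, dst in TORCHRUN_TO_PJRT.items():
        if dst not in env:
            env[dst] = env[src]
        result[dst] = env[dst]
    return result


# Worker with local rank 2 on node_rank=1 of a 2-node, 4-process-per-node job.
example = {'LOCAL_RANK': '2', 'LOCAL_WORLD_SIZE': '4', 'GROUP_RANK': '1',
           'RANK': '6', 'WORLD_SIZE': '8'}
print(translate_torchrun_env(example)['PJRT_RANK'])  # 6
```

Leaving pre-set PJRT_* values alone is a design choice of this sketch; overwriting them unconditionally would be equally valid if torchrun should always be the source of truth.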
Specifying which GPUs to use seems to be a common requirement, because some GPUs may be occupied by other users, so the available ones may not have contiguous IDs. For example, GPUs 0, 1, 4, and 6 may be available. You therefore need to select and specify which GPUs to use, and CUDA_VISIBLE_DEVICES solves this problem. You can find a discussion on this topic here.
I believe a better approach would be to give PJRT default device IDs based on the nproc_per_node argument unless the CUDA_VISIBLE_DEVICES environment variable is set. For instance, with nproc_per_node=2, GPUs 0 and 1 would be allocated to PJRT automatically. Conversely, with CUDA_VISIBLE_DEVICES=3,4 and nproc_per_node=2, GPUs 3 and 4 would be allocated to PJRT.
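The proposed defaulting rule can be sketched as below. select_devices is a hypothetical helper, and integer device IDs are assumed. When CUDA_VISIBLE_DEVICES lists more GPUs than nproc_per_node, this sketch takes the first nproc_per_node of them; raising an error instead would be an equally reasonable choice.

```python
import os


def select_devices(nproc_per_node, env=None):
    """Physical GPU IDs to hand to PJRT on this node.

    - CUDA_VISIBLE_DEVICES unset: default to GPUs 0..nproc_per_node-1.
    - CUDA_VISIBLE_DEVICES set: use those IDs, which must cover every process.
    """
    env = os.environ if env is None else env
    value = env.get('CUDA_VISIBLE_DEVICES')
    if value is None:
        return list(range(nproc_per_node))
    devices = [int(tok) for tok in value.split(',') if tok.strip()]
    if len(devices) < nproc_per_node:
        raise ValueError(
            f'nproc_per_node={nproc_per_node}, but only {len(devices)} GPUs '
            f'are listed in CUDA_VISIBLE_DEVICES={value!r}')
    return devices[:nproc_per_node]


print(select_devices(2, {}))                                   # [0, 1]
print(select_devices(2, {'CUDA_VISIBLE_DEVICES': '3,4'}))      # [3, 4]
print(select_devices(4, {'CUDA_VISIBLE_DEVICES': '0,1,4,6'}))  # [0, 1, 4, 6]
```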
Does it support torchrun?
Support for torchrun was recently added to PyTorch/XLA (pr), so you can use torchrun, for example: torchrun --nnodes 1 --nproc-per-node 4 pytorch/xla/test/pjrt/test_torchrun.py
Here are the reasons why we favor torchrun over the approach in your PR:
- Users should be able to run
  PJRT_DEVICE=GPU torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --rdzv_endpoint=<internal_ip_address> multinode_training.py
  and have everything work in PyTorch/XLA. In contrast, your proposal would require users to make more changes than we want, such as calling gpu.shutdown_distributed_runtime, gpu.initialize_distributed_runtime, pjrt.initialize_multiprocess, etc.
- torchrun can hide much of the implementation detail in your example:
    procs = []
    for config in configs:
      procs.append(Process(target=worker_fn, args=config[0], kwargs=config[1]))
      procs[-1].start()
    for p in procs:
      p.join()
    master_proc.kill()
  torchrun does this for you, so users don't need to write this code.
- torchrun sets a few env vars on our behalf, so we don't have to set them ourselves. Setting that many env vars in code makes it hard to debug.
Is it necessary to set CUDA_VISIBLE_DEVICES?
Sounds good to me.
OK, let me give it a try. It may require a few code modifications.
FWIW, I created a PoC and tested it on a ResNet model by running the following commands on 2 GPU VMs, respectively:
$ PJRT_DEVICE=GPU torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --rdzv_endpoint="10.164.0.1:12355" pytorch/xla/test/test_train_mp_imagenet_torchrun.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
$ PJRT_DEVICE=GPU torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --rdzv_endpoint="10.164.0.1:12355" pytorch/xla/test/test_train_mp_imagenet_torchrun.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
I verified it is working: https://gist.github.com/vanbasten23/87d29e5a763ed166f05710378a6be950
Hey @wbmc, I'm not sure about your progress, but to land the feature sooner, I'd suggest we collaborate on my PoC branch, since I've validated that it works for ResNet on multiple hosts. If you're fine with the plan, we can divide the remaining work. Let me know what you think.
OK, I agree. Is this the branch? https://github.com/pytorch/xla/tree/multihostgpu_poc_3
Yes, that's the branch. The remaining work is:
- set CUDA_VISIBLE_DEVICES to select specific GPUs in case some are occupied.
- tear down the distributed runtime service.
How about you work on the first one and I work on the second one?
Certainly. Can I make changes to the branch by submitting code? If not, I'll have to fork the branch. I don't have permission to submit code directly to pytorch/xla since I'm not a member. Could you please grant me membership access?
Certainly, can I make changes to the branch by submitting code?
Do you mean push commits to my branch? If so, yes we can work that way. Or would it be easier if we work on a fork?
Could you please grant me membership access?
Sorry, I'm not an admin of the repo, so I'm unable to give you membership access.
Btw, Jack has granted you membership access.
Yes, I mean push commits to your branch. I have accepted the invitation. Thank you!
hi @wbmc , how is the work going on your side?
I've run into some issues that haven't been resolved yet. After setting the CUDA_VISIBLE_DEVICES environment variable, I got an error when running all_gather.
What error did you get?
GPU_NUM_DEVICES=1 CUDA_VISIBLE_DEVICES=0 PJRT_DEVICE=GPU torchrun --nproc_per_node=1 --nnodes=2 --node_rank=0 --rdzv_endpoint="127.0.0.1:12355" ./test/pjrt/test_torchrun.py >res1.txt 2>&1 &
GPU_NUM_DEVICES=1 CUDA_VISIBLE_DEVICES=1 PJRT_DEVICE=GPU torchrun --nproc_per_node=1 --nnodes=2 --node_rank=1 --rdzv_endpoint="127.0.0.1:12355" ./test/pjrt/test_torchrun.py >res2.txt 2>&1 &
[W socket.cpp:436] [c10d] The server socket cannot be initialized on [::]:12355 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:663] [c10d] The client socket cannot be initialized to connect to [localhost]:12355 (errno: 97 - Address family not supported by protocol).
Running tests under Python 3.8.17: /opt/conda/bin/python
[ RUN ] TestTorchrun.test_all_gather
[W socket.cpp:663] [c10d] The client socket cannot be initialized to connect to [localhost]:12355 (errno: 97 - Address family not supported by protocol).
dist_world_size 2 1 2
2023-10-07 10:47:38.596501: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: There was an error before calling cuModuleGetFunction (101): cudaErrorInvalidDevice : invalid device ordinal
2023-10-07 10:47:38.596637: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2622] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: There was an error before calling cuModuleGetFunction (101): cudaErrorInvalidDevice : invalid device ordinal; current tracing scope: wrapped_slice.2; current profiling annotation: XlaModule:#hlo_module=SyncTensorsGraph.14,program_id=0#.
[ FAILED ] TestTorchrun.test_all_gather
======================================================================
ERROR: test_all_gather (__main__.TestTorchrun)
TestTorchrun.test_all_gather
----------------------------------------------------------------------
Traceback (most recent call last):
File "./test/pjrt/test_torchrun.py", line 39, in test_all_gather
torch.testing.assert_close(result.cpu(), expected)
RuntimeError: Bad StatusOr access: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: There was an error before calling cuModuleGetFunction (101): cudaErrorInvalidDevice : invalid device ordinal; current tracing scope: wrapped_slice.2; current profiling annotation: XlaModule:#hlo_module=SyncTensorsGraph.14,program_id=0#.
----------------------------------------------------------------------
Ran 1 test in 2.236s
FAILED (errors=1)
Do you see the error when you test on the single host?
Also, judging from your commands above (both nodes use --rdzv_endpoint="127.0.0.1:12355"), are you testing the multi-host case on a single host?
Let me try something as well.
I also got a similar error when trying CUDA_VISIBLE_DEVICES:
2023-10-10 23:47:41.508210: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:2: failed initializing StreamExecutor for CUDA device ordinal 2: INTERNAL: Failed call to cuDeviceGet: CUDA_ERROR_INVALID_DEVICE: invalid device ordinal
Need to look into it. Did you do something like this?
Also, how about we get the current PR merged and continue working on the CUDA_VISIBLE_DEVICES feature? Feel free to take a look.
Yes, we can merge the PR and continue working on the CUDA_VISIBLE_DEVICES feature.
🚀 Feature
Support multiple GPU hosts in the PJRT runtime
Motivation
Users need to run on multiple hosts, each with multiple GPUs.
Pitch
Alternatives
Introduce a TorchElastic-style description of the distributed runtime information.
Additional context