**youkaichao** opened this issue 5 months ago
What's worse, there are many places inside `init_distributed_environment` that can raise an `Exception`, and many synchronization points where the main process and the ray workers wait for each other. Any control divergence during this period (e.g. a ray worker raises an `Exception` while the main process is waiting to create the process group) causes a deadlock.
The core code of `init_distributed_environment` involves the above two functions, and there are many, many possible exceptions and synchronization points along the way. We need to come up with a better way of initializing distributed inference.
You should probably just have Ray pick up the first raised exception (via `ray.wait`) and then kill the rest of the workers when that happens.
> You should probably just have Ray pick up the first raised exception (via `ray.wait`)
The problem is that we don't know whether a worker will raise an exception. Normally we expect all workers (plus the main process) to run smoothly and initialize a process group, but here the main process faces a difficult decision: it cannot poll for worker exceptions while it is itself blocked waiting for process group initialization.
For future reference:

Some nightly builds of PyTorch contain a bug that initializes the CUDA context during `import torch`. This makes the module not pickle-able and causes errors. Combined with the deadlock mechanism discussed in this issue, these buggy torch versions cause a deadlock when used with vLLM, as demonstrated in https://github.com/vllm-project/vllm/issues/3457 .

The code to detect whether you have a buggy torch version is:
```python
# code borrowed from https://github.com/pytorch/pytorch/pull/117010
import ctypes

import torch  # importing torch is the whole point of this test

x = ctypes.c_int(-1)
# `ans` holds the error code, and `x` holds the device count
ans = ctypes.CDLL('libcuda.so.1').cuDeviceGetCount(ctypes.byref(x))
# Normally `import torch` does not initialize CUDA, so we get
# CUDA_ERROR_NOT_INITIALIZED, which is 3. See
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
# for the detailed error codes.
if ans == 3 and x.value == -1:
    print("your torch version is good!")
elif ans == 0:
    print("your torch version contains a bug!")
```
It seems some nightly builds of PyTorch (from torch-2.2.0.dev20231116 to torch-2.3.0.dev20231224, or, to be specific, any torch version containing code from the PR https://github.com/pytorch/pytorch/pull/112623 ) are affected.
> it cannot poll for worker exceptions while it is itself blocked waiting for process group initialization.
Can't you just use multithreading, one to do `ray.wait` and the other to do `dist.init_process_group`?
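For illustration, here is a hedged sketch of that idea (not vLLM's actual implementation; the rendezvous address and helper names are made up): run the blocking `dist.init_process_group` in a background thread while the main thread polls the workers with `ray.wait`, so a worker failure surfaces as an exception instead of a hang.

```python
import threading

import ray
import torch.distributed as dist


def init_group_in_background(rank: int, world_size: int) -> threading.Event:
    """Start process-group initialization in a daemon thread."""
    done = threading.Event()

    def target():
        dist.init_process_group(
            backend="gloo",
            init_method="tcp://127.0.0.1:29500",  # illustrative address
            world_size=world_size,
            rank=rank,
        )
        done.set()

    threading.Thread(target=target, daemon=True).start()
    return done


def init_with_worker_polling(worker_refs, world_size):
    done = init_group_in_background(rank=0, world_size=world_size)
    pending = list(worker_refs)
    while not done.is_set():
        if pending:
            # Poll with a timeout so we regain control even if no worker finishes.
            ready, pending = ray.wait(pending, timeout=1.0)
            for ref in ready:
                ray.get(ref)  # re-raises a failed worker's exception here
        else:
            done.wait(timeout=1.0)
```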
An easier and cleaner alternative is simply to put the driver's result into the Ray object store and then always call `ray.get()` on all result ObjectRefs.
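For concreteness, a sketch of that suggestion (the `worker_part`/`driver_part` helpers are hypothetical stand-ins for the real initialization work):

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def worker_part():
    return "worker done"


def driver_part():
    # Stand-in for the driver's own (blocking) initialization work.
    return "driver done"


worker_refs = [worker_part.remote()]
# Put the driver's result into the object store too, then get everything:
refs = worker_refs + [ray.put(driver_part())]
print(ray.get(refs))  # a failed worker's exception would re-raise here
```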
This idea does not work; we always need concurrent polling if `len(workers) > 0`, because the driver blocks in its own work before it can ever call `ray.get()`.
Will be fixed by: https://github.com/vllm-project/vllm/issues/6556
### Your current environment

Any distributed inference task with ray currently suffers from this issue.

### 🐛 Describe the bug
#### Basic background of `ray`

`ray` provides an easy-to-use asynchronous execution framework:
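(A minimal sketch of the usage pattern; the function name here is illustrative.)

```python
import ray

ray.init()


@ray.remote
def square(x):
    # runs asynchronously in a separate ray worker process
    return x * x


ref = square.remote(4)  # returns an ObjectRef immediately, without blocking
print(ray.get(ref))     # blocks until the result is ready; prints 16
```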
The way it deals with `Exception` is noteworthy; see the comments in the sketch below:
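(Again a sketch with illustrative names; the key point is where the exception surfaces.)

```python
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def fail():
    raise RuntimeError("worker failed")


ref = fail.remote()  # does NOT raise; the task silently enters an error state

# The exception is only re-raised in the caller when ray.get() is called.
# If the caller never reaches ray.get() because it is blocked elsewhere,
# the error never surfaces -- that is the root of the deadlock below.
try:
    ray.get(ref)
except RuntimeError as e:
    print("surfaced at ray.get:", e)
```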
#### The deadlock in distributed inference

The deadlock happens during the initialization of distributed inference, i.e. when creating the process group for collaboration.
A minimal reproducible example looks like this:
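(The following is a sketch of the pattern described below, assuming a two-rank `gloo` process group with an illustrative TCP rendezvous address.)

```python
import ray
import torch.distributed as dist

ray.init()


@ray.remote
def f():
    # If anything here raises before init_process_group, the worker dies,
    # but the main process below keeps waiting for rank 1 to join.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        world_size=2,
        rank=1,
    )
    return "worker: process group initialized"


ref = f.remote()

# The main process blocks here until the worker joins the process group.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    world_size=2,
    rank=0,
)
print("main: process group initialized")
print(ray.get(ref))  # the worker's exception would only surface here
```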
Normally it works with the following output:
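With the sketch above, that output would be roughly:

```
main: process group initialized
worker: process group initialized
```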
However, if the `f` function throws an exception before calling `dist.init_process_group`, it is kept in an error state, waiting for the main process to call `ray.get` to error out; meanwhile, the main process is stuck at `dist.init_process_group`, waiting for the worker process to join it and initialize the process group for the distributed environment. Together they cause a deadlock.

#### How is this related with vLLM
vLLM uses `ray` for distributed inference, and the core code is attached below:

https://github.com/vllm-project/vllm/blob/6b78837b29b5045a71e6ecfa68442b1f4fd2d0a6/vllm/executor/ray_gpu_executor.py#L299-L351
When calling `init_model`, both the ray workers and the main process reach the following function:

https://github.com/vllm-project/vllm/blob/abfc4f3387c436d46d6701e9ba916de8f9ed9329/vllm/worker/worker.py#L71-L96
And essentially we are back to the minimal reproducible example mentioned before: any exception raised before `init_distributed_environment` can cause a deadlock. In my case, my GPU driver had a problem, and `torch.cuda.set_device` raised an exception, causing the deadlock.

#### Solution to be discussed
Any suggestion to fix this is welcome.
Might be related: https://github.com/vllm-project/vllm/issues/2466 .