
[Bug]: Dead lock in distributed inference when ray worker raises an exception #3455

Open youkaichao opened 5 months ago

youkaichao commented 5 months ago

Your current environment

Any distributed inference task using ray currently suffers from this issue.

🐛 Describe the bug

Basic background of ray

ray provides an easy-to-use asynchronous execution framework:

def f():
    print(1)

import ray
ray.init()
marked_function = ray.remote(f) # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote() # schedule a worker to asynchronously execute the function, immediately return a handle
result = ray.get(handle) # synchronously wait for the worker to finish and return the result

The way it deals with exceptions is noteworthy; see the comments below:

def f():
    print(1)
    raise RuntimeError("test")
    # the following line will not be executed
    print(2)

import ray
ray.init()
marked_function = ray.remote(f) # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote() # schedule a worker to asynchronously execute the function, immediately return a handle

# ... do other work in the meantime ...
# the main process will not be notified if the worker fails

# only when we call `ray.get` will we be notified of the error
result = ray.get(handle) # raise the error that was thrown in the worker, wrapping it in a RayTaskError
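
For reference, ray also offers ray.wait, which lets the caller check whether a task has finished (successfully or with an error) without blocking indefinitely. A minimal sketch, reusing the `handle` from the example above:

ready, not_ready = ray.wait([handle], timeout=0)  # non-blocking poll
if ready:
    # the task has finished; ray.get re-raises any exception it hit
    result = ray.get(ready[0])

This polling primitive comes up again in the discussion below.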

The deadlock in distributed inference

The deadlock happens during the initialization of distributed inference, i.e., when creating the process group that the processes use to collaborate.

A minimal reproducible example looks like this:

import torch
import torch.distributed as dist

def f(rank, world_size, distributed_init_method):
    # raise RuntimeError  # uncomment this line to see the deadlock
    dist.init_process_group(
        backend="gloo",
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank,
    )
    tensor = torch.zeros(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"Rank {rank} has data {tensor.item()}")

import ray
ray.init()
marked_function = ray.remote(f)

distributed_init_method = "tcp://127.0.0.1:29500"
world_size = 2

# start the first process
handle = marked_function.remote(rank=0, world_size=world_size, distributed_init_method=distributed_init_method)

# the main process is the second process
# wait for the first process to join here to initialize the process group for distributed environment
dist.init_process_group(backend="gloo", init_method=distributed_init_method, world_size=world_size, rank=1)

# two processes are ready to communicate
tensor = torch.ones(1)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"Rank 1 has data {tensor.item()}")

result = ray.get(handle)

Normally it works with the following output:

2024-03-17 10:24:23,293 INFO worker.py:1724 -- Started a local Ray instance.
Rank 1 has data 1.0
(f pid=14616) Rank 0 has data 1.0

However, if the function f throws an exception before calling dist.init_process_group, the worker is left in an error state, waiting for the main process to call ray.get to surface the error; meanwhile, the main process is stuck in dist.init_process_group, waiting for the worker process to join and initialize the process group. Together they cause a deadlock.
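
One partial mitigation (a sketch only, not a full fix) is to bound the hang: torch.distributed.init_process_group accepts a timeout, so the main process errors out after a while instead of waiting forever. For the rank-1 call in the example above, assuming a 60-second budget:

import datetime
import torch.distributed as dist

# the rendezvous gives up after 60 seconds instead of blocking forever
dist.init_process_group(
    backend="gloo",
    init_method=distributed_init_method,
    world_size=world_size,
    rank=1,
    timeout=datetime.timedelta(seconds=60),
)

This turns the deadlock into a delayed error, but it does not surface the worker's original exception.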

How is this related to vLLM

vLLM uses ray for distributed inference, and the core code is attached below:

https://github.com/vllm-project/vllm/blob/6b78837b29b5045a71e6ecfa68442b1f4fd2d0a6/vllm/executor/ray_gpu_executor.py#L299-L351

When init_model is called, both the ray workers and the main process reach the following function:

https://github.com/vllm-project/vllm/blob/abfc4f3387c436d46d6701e9ba916de8f9ed9329/vllm/worker/worker.py#L71-L96

Essentially, we are back to the minimal reproducible example above: any exception raised before init_distributed_environment can cause a deadlock.

In my case, my GPU driver had a problem: torch.cuda.set_device raised an exception, which caused the deadlock.

Solution to be discussed

Any suggestions for fixing this are welcome.

Might be related: https://github.com/vllm-project/vllm/issues/2466 .

youkaichao commented 5 months ago

What's worse, there are many places inside init_distributed_environment that can raise an exception, and many synchronization points where the main process and the ray workers wait for each other.

Any control-flow divergence during this period (e.g., a ray worker raises an exception while the main process is waiting to create the process group) causes a deadlock.

https://github.com/vllm-project/vllm/blob/abfc4f3387c436d46d6701e9ba916de8f9ed9329/vllm/worker/worker.py#L252-L305

https://github.com/vllm-project/vllm/blob/abfc4f3387c436d46d6701e9ba916de8f9ed9329/vllm/model_executor/parallel_utils/cupy_utils.py#L70-L96

The core code of init_distributed_environment involves the two functions above, and there are many possible exception sites and synchronization points.

We need to come up with a better way to initialize distributed inference.
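
One possible direction, sketched here purely as an illustration (the Worker class and its preflight/init_dist methods are hypothetical names, not vLLM APIs, and distributed_init_method/world_size are reused from the minimal example above): split worker initialization into a non-collective "preflight" phase that the driver verifies via ray.get before any process enters the blocking collective setup, so worker-side exceptions surface through ray instead of deadlocking.

import ray
import torch
import torch.distributed as dist

@ray.remote
class Worker:
    def __init__(self, rank):
        self.rank = rank

    def preflight(self):
        # failure-prone, non-collective setup; an exception here is
        # surfaced to the driver by ray.get, before anything blocks
        torch.cuda.set_device(self.rank)

    def init_dist(self, init_method, world_size):
        dist.init_process_group(backend="gloo", init_method=init_method,
                                world_size=world_size, rank=self.rank)

workers = [Worker.remote(r) for r in range(1, world_size)]
# step 1: confirm every worker survives the risky setup
ray.get([w.preflight.remote() for w in workers])
# step 2: only now does everyone enter the blocking collective initialization
handles = [w.init_dist.remote(distributed_init_method, world_size) for w in workers]

This does not remove every synchronization hazard inside init_distributed_environment, but it moves the most failure-prone steps in front of the first blocking call.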

richardliaw commented 5 months ago

You should probably just have Ray pick up the first raised exception (via ray.wait) and then kill the rest of the workers when that happens
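
A rough sketch of that pattern, assuming handles is a list of in-flight worker ObjectRefs from remote tasks (for actor-based workers, ray.kill on the actor handles would be the analogue of ray.cancel):

import ray

# the first ref to finish, successfully or with an error, lands in `done`
done, pending = ray.wait(handles, num_returns=1)
try:
    ray.get(done[0])  # re-raises the worker's exception, if any
except Exception:
    for ref in pending:
        ray.cancel(ref, force=True)  # stop the remaining workers
    raise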

youkaichao commented 5 months ago

You should probably just have Ray pick up the first raised exception (via ray.wait)

The problem is that we don't know in advance whether a worker will raise an exception. Normally we expect all workers (plus the main process) to run smoothly and initialize the process group, but here the main process faces a difficult decision: it cannot poll for worker exceptions while it is itself blocked initializing the process group.

youkaichao commented 5 months ago

For future reference:

Some nightly builds of pytorch contain a bug that initializes the CUDA context during import torch. This makes the module un-picklable and causes an error. Combined with the deadlock mechanism discussed in this issue, these buggy torch versions cause a deadlock when used with vllm, as demonstrated in https://github.com/vllm-project/vllm/issues/3457 .

The code to detect whether we have a buggy torch version is:

# code borrowed from https://github.com/pytorch/pytorch/pull/117010

import torch
import ctypes
x = ctypes.c_int(-1)
# `ans` holds the error code, and `x` holds the device count
ans = ctypes.CDLL('libcuda.so.1').cuDeviceGetCount(ctypes.byref(x))

# normally, `import torch` does not initialize cuda, so we get CUDA_ERROR_NOT_INITIALIZED, which is 3
# check https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html for the detailed error codes
if ans == 3 and x.value == -1:
    print("your torch version is good!")

if ans == 0:
    print("your torch version contains a bug!")

It seems some nightly builds of pytorch (from torch-2.2.0.dev20231116 to torch-2.3.0.dev20231224; to be specific, any torch version that contains code from the PR https://github.com/pytorch/pytorch/pull/112623 ) are affected.

jon-chuang commented 1 month ago

It cannot wait and test worker exception while waiting for initializing a process group at the same time.

Can't you just use multithreading, with one thread doing ray.wait and the other doing dist.init_process_group?
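
A rough sketch of that idea, reusing names from the minimal example above (distributed_init_method, world_size) and assuming handles is a list of worker ObjectRefs; whether vLLM should actually call init_process_group off the main thread is a separate question:

import threading
import ray
import torch.distributed as dist

def join_group():
    dist.init_process_group(backend="gloo", init_method=distributed_init_method,
                            world_size=world_size, rank=1)

# daemon thread so the process can still exit if a worker failure is raised below
t = threading.Thread(target=join_group, daemon=True)
t.start()

# while rank 1 joins the group in the background thread,
# the main thread polls the workers for an early failure
pending = list(handles)
while t.is_alive() and pending:
    done, pending = ray.wait(pending, timeout=1)
    for ref in done:
        ray.get(ref)  # re-raises a worker exception instead of deadlocking
t.join()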

jon-chuang commented 1 month ago

An easier and cleaner alternative is simply to put the result of the driver into the ray object store and then always call ray.get() on all result ObjectRefs.

This idea does not work; concurrent polling is always needed if len(workers) > 0.

jon-chuang commented 1 month ago

Will be fixed by: https://github.com/vllm-project/vllm/issues/6556