vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Ray distributed backend does not support out-of-tree models via ModelRegistry APIs #5657

Open SamKG opened 2 weeks ago

SamKG commented 2 weeks ago

Your current environment

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Amazon Linux 2 (x86_64)
GCC version: (GCC) 7.3.1 20180712 (Red Hat 7.3.1-17)
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.26

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.10.217-205.860.amzn2.x86_64-x86_64-with-glibc2.26
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              96
On-line CPU(s) list: 0-95
Thread(s) per core:  2
Core(s) per socket:  24
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping:            7
CPU MHz:             3599.164
BogoMIPS:            5999.99
Hypervisor vendor:   KVM
Virtualization type: full
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            36608K
NUMA node0 CPU(s):   0-23,48-71
NUMA node1 CPU(s):   24-47,72-95
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0+cu121
[pip3] transformers==4.41.2
[pip3] triton==2.3.0
[conda] No relevant packages
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.0
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV12    NV12    NV12    NV12    NV12    NV12    NV12    0-23,48-71      0               N/A
GPU1    NV12     X      NV12    NV12    NV12    NV12    NV12    NV12    0-23,48-71      0               N/A
GPU2    NV12    NV12     X      NV12    NV12    NV12    NV12    NV12    0-23,48-71      0               N/A
GPU3    NV12    NV12    NV12     X      NV12    NV12    NV12    NV12    0-23,48-71      0               N/A
GPU4    NV12    NV12    NV12    NV12     X      NV12    NV12    NV12    24-47,72-95     1               N/A
GPU5    NV12    NV12    NV12    NV12    NV12     X      NV12    NV12    24-47,72-95     1               N/A
GPU6    NV12    NV12    NV12    NV12    NV12    NV12     X      NV12    24-47,72-95     1               N/A
GPU7    NV12    NV12    NV12    NV12    NV12    NV12    NV12     X      24-47,72-95     1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

The Ray distributed backend does not support out-of-tree models registered via ModelRegistry (observed on a single node).

Repro:


from vllm import ModelRegistry
from vllm.model_executor.models.mixtral import MixtralForCausalLM
ModelRegistry.register_model("SomeModel", MixtralForCausalLM)

from vllm import LLM, SamplingParams

if __name__ == "__main__": 
    llm = LLM(
        model="SomeModel/", # a locally downloaded Mixtral checkpoint whose config.json sets "architectures": ["SomeModel"]
        tensor_parallel_size=8,
        # distributed_executor_backend="ray", # ray backend fails!
    )
youkaichao commented 2 weeks ago

will take a look later

youkaichao commented 2 weeks ago

Sorry, I don't get it. Out-of-tree (OOT) model registration is meant for the architecture name that appears in the Hugging Face config file, not for the model argument passed to LLM.

See https://huggingface.co/facebook/opt-125m/blob/main/config.json#L6 for an example.
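
Put another way, the lookup is keyed by the "architectures" entry inside config.json, not by the directory path passed as model=. The sketch below illustrates that with a toy, name-keyed registry; REGISTRY, register_model and resolve_architecture here are hypothetical stand-ins written for illustration, not vLLM's actual ModelRegistry code.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical registry: architecture name (as listed in config.json) -> implementation.
# It only illustrates name-keyed lookup; it is not vLLM's ModelRegistry.
REGISTRY: dict[str, str] = {"OPTForCausalLM": "OPTForCausalLM"}


def register_model(arch_name: str, impl: str) -> None:
    """Map an architecture name to an implementation (here just a label)."""
    REGISTRY[arch_name] = impl


def resolve_architecture(model_dir: str) -> str:
    """Read model_dir/config.json and return the implementation registered for
    the first architecture it lists that the registry knows about."""
    config = json.loads((Path(model_dir) / "config.json").read_text())
    archs = config.get("architectures", [])
    for arch in archs:
        if arch in REGISTRY:
            return REGISTRY[arch]
    raise ValueError(f"Model architectures {archs} are not registered")


with tempfile.TemporaryDirectory() as model_dir:
    # The directory path is irrelevant to the lookup;
    # only the "architectures" entry inside config.json matters.
    config_path = Path(model_dir) / "config.json"
    config_path.write_text(json.dumps({"architectures": ["SomeModel"]}))

    try:
        resolve_architecture(model_dir)
    except ValueError as err:
        print(err)  # Model architectures ['SomeModel'] are not registered

    register_model("SomeModel", "MixtralForCausalLM")
    print(resolve_architecture(model_dir))  # MixtralForCausalLM
```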

SamKG commented 2 weeks ago

Yes, this is how I am using it. For context, the "SomeModel/" directory here contains a config.json file that references my custom architecture. For clarity, here is the same example with a neutral directory name:

from vllm import ModelRegistry
from vllm.model_executor.models.mixtral import MixtralForCausalLM
ModelRegistry.register_model("SomeModel", MixtralForCausalLM)

from vllm import LLM, SamplingParams

if __name__ == "__main__": 
    llm = LLM(
        model="path_to_directory/", # directory which has a config.json with architectures: ["SomeModel"]
        tensor_parallel_size=8,
        # distributed_executor_backend="ray", # ray backend fails!
    )
youkaichao commented 2 weeks ago

Then it makes sense to me. The Ray workers do not know about "SomeModel", because the following code:

from vllm import ModelRegistry
from vllm.model_executor.models.mixtral import MixtralForCausalLM
ModelRegistry.register_model("SomeModel", MixtralForCausalLM)

is not executed in the Ray workers.
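
The root cause is that a registry like this is ordinary per-process Python state: registering a model in the driver mutates only the driver's memory, and a worker process that never executes the registration code starts without it. The sketch below reproduces that effect without vLLM or Ray, purely as an analogy: the stdlib mimetypes registry stands in for ModelRegistry, mimetypes.add_type for register_model, and a fresh interpreter launched with subprocess for a Ray worker. The ".somemodel" extension and its MIME type are made up for illustration.

```python
import mimetypes
import subprocess
import sys

FAKE_TYPE = "application/x-somemodel"  # made-up type, playing the role of "SomeModel"
FAKE_FILE = "weights.somemodel"        # made-up file name using the new extension

# "Driver": register in this process (analogous to ModelRegistry.register_model).
mimetypes.add_type(FAKE_TYPE, ".somemodel")
driver_sees = mimetypes.guess_type(FAKE_FILE)[0]

# "Worker": a brand-new interpreter that never runs the add_type call above.
worker_code = f"import mimetypes; print(mimetypes.guess_type({FAKE_FILE!r})[0])"
result = subprocess.run(
    [sys.executable, "-c", worker_code],
    capture_output=True,
    text=True,
    check=True,
)
worker_sees = result.stdout.strip()

print("driver:", driver_sees)  # driver: application/x-somemodel
print("worker:", worker_sees)  # worker: None
```

Running the registration inside the worker as well (for example by prepending the add_type call to worker_code) makes the lookup succeed there too, which is the same idea as the fix this thread arrives at.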

SamKG commented 2 weeks ago

Thanks! Is there a way to run this initialization on the Ray workers?

youkaichao commented 1 week ago

@SamKG so the default backend (multiprocessing) should work out-of-the-box, right?

richardliaw commented 1 week ago

Also cc @rkooo567: maybe this is solvable via Ray's runtime_env.

@SamKG is there a full repro somewhere we can look at?

SamKG commented 1 week ago

@richardliaw Try attached. Note that the default backend will also fail (but with an expected error), since I added a stub tensor to keep the model directory small.

repro.tar.gz

@youkaichao yes, the default backend works fine (as long as the OOT registration happens outside the if __name__ == "__main__": block).
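
The "outside of main" caveat is consistent with how Python's main guard works, under one assumption: that worker processes are started with multiprocessing's spawn method, which re-imports the driver script in each child under the module name "__mp_main__". This thread does not say which start method vLLM uses; with fork, children would instead inherit registrations the parent had already made. Under spawn, module-level code runs again in the child, but the body of the main guard does not. The sketch simulates that re-import with exec; DRIVER_SCRIPT and run_script_as are hypothetical.

```python
# Hypothetical driver script: one registration at module level, one inside the main guard.
DRIVER_SCRIPT = """
REGISTRY = {}
REGISTRY["ModuleLevelModel"] = "MixtralForCausalLM"

if __name__ == "__main__":
    REGISTRY["GuardedModel"] = "MixtralForCausalLM"
"""


def run_script_as(module_name: str) -> dict:
    """Execute DRIVER_SCRIPT with the given __name__ and return the resulting registry."""
    namespace = {"__name__": module_name}
    exec(DRIVER_SCRIPT, namespace)
    return namespace["REGISTRY"]


# The process you launch runs the script as "__main__".
driver_registry = run_script_as("__main__")
# A spawn-started child re-imports the same script as "__mp_main__".
child_registry = run_script_as("__mp_main__")

print(sorted(driver_registry))  # ['GuardedModel', 'ModuleLevelModel']
print(sorted(child_registry))   # ['ModuleLevelModel']
```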

rkooo567 commented 1 week ago

ray.init(runtime_env={"worker_process_setup_hook": ...}) lets you run code on every worker. Would this suffice?
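
Conceptually, a setup hook runs the same registration code inside every worker before it handles any work, which is exactly what driver-only registration is missing. Below is a pure-Python simulation of that idea with no Ray involved. SimulatedWorker, register_oot_models and start_workers are hypothetical; in this sketch the hook receives the worker's registry as an argument, whereas a real hook would be a zero-argument function that calls ModelRegistry.register_model.

```python
from typing import Callable, Dict, List, Optional

Registry = Dict[str, str]
SetupHook = Callable[[Registry], None]


class SimulatedWorker:
    """Hypothetical stand-in for a Ray worker process. Like a fresh process, it
    starts with its own empty registry and never sees the driver's registrations."""

    def __init__(self, setup_hook: Optional[SetupHook] = None) -> None:
        self.registry: Registry = {}
        if setup_hook is not None:
            # Analogue of worker_process_setup_hook: runs once per worker,
            # before the worker handles any task.
            setup_hook(self.registry)

    def load_model(self, arch: str) -> str:
        if arch not in self.registry:
            raise KeyError(f"worker does not know architecture {arch!r}")
        return self.registry[arch]


def register_oot_models(registry: Registry) -> None:
    """Hypothetical counterpart of the registration code the driver runs."""
    registry["SomeModel"] = "MixtralForCausalLM"


def start_workers(n: int, setup_hook: Optional[SetupHook] = None) -> List[SimulatedWorker]:
    return [SimulatedWorker(setup_hook) for _ in range(n)]


# Without a hook, each of the 8 workers fails to resolve "SomeModel".
plain = start_workers(8)
failures = 0
for worker in plain:
    try:
        worker.load_model("SomeModel")
    except KeyError:
        failures += 1
print(failures)  # 8

# With the hook, every worker runs the registration itself first.
hooked = start_workers(8, setup_hook=register_oot_models)
print({w.load_model("SomeModel") for w in hooked})  # {'MixtralForCausalLM'}
```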

youkaichao commented 1 week ago

@rkooo567 this functionality seems related, but how can we expose it to users?

SamKG commented 1 week ago

this seems to fix the issue!

import ray
from vllm import ModelRegistry, LLM

def _init_worker():
    # Register the out-of-tree architecture named in config.json ("SomeModel").
    from vllm.model_executor.models.mixtral import MixtralForCausalLM
    ModelRegistry.register_model("SomeModel", MixtralForCausalLM)

# Register in the driver process.
_init_worker()

if __name__ == "__main__":
    # Run the same registration in every Ray worker process at startup.
    ray.init(runtime_env={"worker_process_setup_hook": _init_worker})
    llm = LLM(
        model="model/",
        tensor_parallel_size=8,
        distributed_executor_backend="ray",
    )
    llm.generate("test")
richardliaw commented 1 week ago

very nice!

@youkaichao maybe we can just print a warning linking to the vLLM docs about this?

And in the vLLM docs, let's have an example snippet like the one above!
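
A minimal sketch of the suggested warning, assuming the check runs wherever the executor backend is chosen and that the set of built-in architecture names is available there. The function name, its parameters, and the message are hypothetical, not existing vLLM code, and DOCS_HINT is a placeholder because no docs URL has been decided.

```python
import warnings
from typing import Iterable, Set

# Placeholder: the thread proposes linking to the vLLM docs but gives no URL.
DOCS_HINT = "See the vLLM docs on out-of-tree models (link to be added)."


def warn_if_oot_with_ray(architectures: Iterable[str], builtin_archs: Set[str], backend: str) -> bool:
    """Warn when out-of-tree architectures are combined with the Ray backend.

    Returns True if a warning was emitted, False otherwise.
    """
    oot = sorted(set(architectures) - builtin_archs)
    if backend == "ray" and oot:
        warnings.warn(
            f"Architectures {oot} are not built in. With distributed_executor_backend='ray', "
            "ModelRegistry.register_model calls made in the driver are not run in Ray workers; "
            "register them in every worker, e.g. via "
            "ray.init(runtime_env={'worker_process_setup_hook': ...}). " + DOCS_HINT,
            UserWarning,
            stacklevel=2,
        )
        return True
    return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    emitted = warn_if_oot_with_ray(["SomeModel"], {"MixtralForCausalLM"}, "ray")
print(emitted, len(caught))  # True 1
```

Keying the check on the backend keeps the default (multiprocessing) path silent, since the thread reports that it already works when registration happens at module level.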