vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: No CUDA GPUs are available on 'CPU' use #4858

Closed mcr-ksh closed 1 week ago

mcr-ksh commented 6 months ago

Your current environment

PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.2
Libc version: glibc-2.35

Python version: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      43 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             12
On-line CPU(s) list:                0-11
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 5 6600H with Radeon Graphics
CPU family:                         23
Model:                              0
Thread(s) per core:                 1
Core(s) per socket:                 1
Socket(s):                          12
Stepping:                           0
BogoMIPS:                           6587.62
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw cpb ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec clzero arat overflow_recov succor
L1d cache:                          384 KiB (12 instances)
L1i cache:                          384 KiB (12 instances)
L2 cache:                           6 MiB (12 instances)
L3 cache:                           192 MiB (12 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT disabled
Vulnerability Spec rstack overflow: Mitigation; SMT disabled
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.18.1
[pip3] torch==2.1.2
[pip3] triton==2.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
[conda] torch                     2.1.2                    pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.0.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

🐛 Describe the bug

Starting the server with --device cpu throws a CUDA error. It should run on the CPU anyway.

(ldb) root@irsai:/devel/LLMDebugger# python -m vllm.entrypoints.openai.api_server --model bigcode/starcoder --device cpu

INFO 05-16 11:41:07 api_server.py:149] vLLM API server version 0.4.0.post1
INFO 05-16 11:41:07 api_server.py:150] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, served_model_name=None, lora_modules=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], model='bigcode/starcoder', tokenizer=None, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='auto', kv_cache_dtype='auto', max_model_len=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, gpu_memory_utilization=0.9, forced_num_gpu_blocks=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=5, disable_log_stats=False, quantization=None, enforce_eager=False, max_context_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', max_cpu_loras=None, device='cpu', image_input_type=None, image_token_id=None, image_input_shape=None, image_feature_size=None, scheduler_delay_factor=0.0, enable_chunked_prefill=False, engine_use_ray=False, disable_log_requests=False, max_log_len=None)
INFO 05-16 11:41:07 llm_engine.py:74] Initializing an LLM engine (v0.4.0.post1) with config: model='bigcode/starcoder', tokenizer='bigcode/starcoder', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cpu, seed=0)
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 157, in <module>
    engine = AsyncLLMEngine.from_engine_args(
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 348, in from_engine_args
    engine = cls(
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 311, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 422, in _init_engine
    return engine_class(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 110, in __init__
    self.model_executor = executor_class(model_config, cache_config,
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 37, in __init__
    self._init_worker()
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 52, in _init_worker
    self.driver_worker = Worker(
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/worker/worker.py", line 63, in __init__
    self.model_runner = ModelRunner(
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 90, in __init__
    self.attn_backend = get_attn_backend(
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/attention/selector.py", line 15, in get_attn_backend
    if _can_use_flash_attn(dtype):
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/vllm/attention/selector.py", line 38, in _can_use_flash_attn
    if torch.cuda.get_device_capability()[0] < 8:
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/torch/cuda/__init__.py", line 435, in get_device_capability
    prop = get_device_properties(device)
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/torch/cuda/__init__.py", line 449, in get_device_properties
    _lazy_init()  # will define _get_device_properties
  File "/home/ubuntu/miniconda3/envs/ldb/lib/python3.10/site-packages/torch/cuda/__init__.py", line 298, in _lazy_init
    torch._C._cuda_init()
RuntimeError: No CUDA GPUs are available
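For reference, the failure happens in vLLM's attention backend selector, which probes the CUDA device capability before the CPU device configuration is consulted. A minimal sketch of the failing pattern and a device-aware guard (this is an illustration based on the traceback above, not vLLM's actual code):

import torch

def can_use_flash_attn_unguarded() -> bool:
    # Mirrors the failing call chain: on a host with no CUDA devices,
    # torch.cuda.get_device_capability() raises
    # "RuntimeError: No CUDA GPUs are available".
    return torch.cuda.get_device_capability()[0] >= 8

def can_use_flash_attn_guarded(device: str) -> bool:
    # Hypothetical device-aware variant: short-circuit for CPU
    # before touching torch.cuda at all.
    if device == "cpu" or not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability()[0] >= 8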
farshadghodsian commented 6 months ago

Although this doesn't solve the bug, if you would like to get things working and stop vLLM from trying to use your integrated Radeon Graphics, you can set CUDA_VISIBLE_DEVICES=-1. I tried this together with --device=cpu and it is working correctly for me.
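A sketch of that workaround; the environment variable must be set before torch/vLLM initialize CUDA, and the model name is simply the one from the report above:

import os
# Hide all CUDA devices before anything imports torch, so vLLM cannot try to use them.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import torch
print(torch.cuda.is_available())  # expected: False once no devices are visible

# Shell equivalent for the OpenAI-compatible server:
#   CUDA_VISIBLE_DEVICES=-1 python -m vllm.entrypoints.openai.api_server \
#       --model bigcode/starcoder --device cpu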

casassg commented 6 months ago

+1 to this issue. The error seems to occur when vLLM is installed without the CPU build. Currently the attention backend is chosen based on whether the installed version of vLLM has the CPU suffix or not (https://github.com/vllm-project/vllm/blob/main/vllm/attention/selector.py#L84 -> https://github.com/vllm-project/vllm/blob/main/vllm/utils.py#L131). This means that even when you specify the device as cpu, vLLM tries to load one of the other attention backends.

https://github.com/vllm-project/vllm/pull/4962 is a potential solution (effectively passing the CPU attention backend flag down from the worker). A rough sketch of the idea follows below.
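A rough sketch of that idea (not the actual patch in the PR): have the selector branch on the configured device before any torch.cuda call, so the CPU backend is chosen directly. The backend names here are only illustrative:

from enum import Enum, auto

import torch

class AttnBackend(Enum):
    FLASH_ATTN = auto()   # GPU, compute capability >= 8.0
    XFORMERS = auto()     # GPU fallback
    TORCH_SDPA = auto()   # CPU backend

def select_attn_backend(device: str) -> AttnBackend:
    # If the engine is configured for CPU, pick the CPU backend without any CUDA probing.
    if device == "cpu":
        return AttnBackend.TORCH_SDPA
    # Only probe the GPU when a GPU device was actually requested.
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        return AttnBackend.FLASH_ATTN
    return AttnBackend.XFORMERS

print(select_attn_backend("cpu"))  # AttnBackend.TORCH_SDPA, with no CUDA calls made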

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions[bot] commented 1 week ago

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!