vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details) #7548

Open zhaotyer opened 3 months ago

zhaotyer commented 3 months ago

Your current environment

The output of `python collect_env.py`

```text
root@newllm201:/workspace# vim collect.py
root@newllm201:/workspace# python3 collect.py
Collecting environment information...
PyTorch version: 2.3.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.0
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jul 29 2024, 17:02:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.104.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 112
On-line CPU(s) list: 0-111
Thread(s) per core: 2
Core(s) per socket: 28
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz
Stepping: 6
Frequency boost: enabled
CPU MHz: 1100.000
CPU max MHz: 2601.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Virtualization: VT-x
L1d cache: 2.6 MiB
L1i cache: 1.8 MiB
L2 cache: 70 MiB
L3 cache: 84 MiB
NUMA node0 CPU(s): 0-27,56-83
NUMA node1 CPU(s): 28-55,84-111
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; Load fences, usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 invpcid_single intel_pt ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq md_clear pconfig spec_ctrl intel_stibp flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] nvidia-nccl-cu11==2.20.5
[pip3] onnx==1.15.0
[pip3] paddle2onnx==1.1.0
[pip3] torch==2.3.1+cu118
[pip3] torchaudio==2.3.1+cu118
[pip3] torchtext==0.5.0
[pip3] torchvision==0.18.1+cu118
[pip3] triton==2.3.1
[pip3] tritonclient==2.19.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0    GPU1    GPU2    GPU3    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      SYS     SYS     SYS     PXB     PXB     0-27,56-83      0               N/A
GPU1    SYS      X      NV12    PXB     SYS     SYS     28-55,84-111    1               N/A
GPU2    SYS     NV12     X      PXB     SYS     SYS     28-55,84-111    1               N/A
GPU3    SYS     PXB     PXB      X      SYS     SYS     28-55,84-111    1               N/A
NIC0    PXB     SYS     SYS     SYS      X      PIX
NIC1    PXB     SYS     SYS     SYS     PIX      X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
```
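It is not the cause of the failure, but the GPU topology above also explains the `Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs` warnings in the log: GPU0 reaches every other GPU only over `SYS`, so the four GPUs are not all NVLink-connected. The sketch below reproduces that check from the printed matrix. It assumes the rule implied by the warning text (more than two GPUs need NVLink between every pair); vLLM itself queries the hardware rather than parsing this table, and the helper names are hypothetical.

```python
from typing import Dict

Links = Dict[str, Dict[str, str]]

# GPU-to-GPU block of the topology matrix reported above
# (NIC and affinity columns dropped).
TOPO = """\
GPU0 X    SYS  SYS  SYS
GPU1 SYS  X    NV12 PXB
GPU2 SYS  NV12 X    PXB
GPU3 SYS  PXB  PXB  X
"""

def parse_gpu_links(text: str) -> Links:
    """Map each GPU name to {peer GPU name: link type}."""
    rows = [line.split() for line in text.strip().splitlines()]
    names = [row[0] for row in rows]
    return {row[0]: dict(zip(names, row[1:])) for row in rows}

def is_full_nvlink(links: Links) -> bool:
    """True only if every pair of distinct GPUs is joined by NVLink ("NV#")."""
    return all(
        link.startswith("NV")
        for gpu, peers in links.items()
        for peer, link in peers.items()
        if peer != gpu
    )

def custom_allreduce_allowed(links: Links) -> bool:
    # Assumed rule, read off the vLLM warning: more than two GPUs need full NVLink.
    return len(links) <= 2 or is_full_nvlink(links)

links = parse_gpu_links(TOPO)
print(links["GPU0"]["GPU1"])            # SYS
print(custom_allreduce_allowed(links))  # False
```

Custom all-reduce being disabled only means vLLM falls back to NCCL for all-reduce, which is the code path that then fails during CUDA graph capture.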

🐛 Describe the bug

LLM model: Qwen/Qwen2-72B-Instruct. Command executed:

```shell
NCCL_DEBUG=WARN python3 -m vllm.entrypoints.openai.api_server --model=/workspace/atom/1/local_model/base_model/ -tp 4
```

Error output:

root@newllm201:/workspace# NCCL_DEBUG=WARN python3 -m vllm.entrypoints.openai.api_server --model=/workspace/atom/1/local_model/base_model/ -tp 4
INFO 08-15 07:47:02 api_server.py:219] vLLM API server version 0.5.3.post1
INFO 08-15 07:47:02 api_server.py:220] args: Namespace(allow_credentials=False, allowed_headers=['*'], allowed_methods=['*'], allowed_origins=['*'], api_key=None, block_size=16, chat_template=None, code_revision=None, cpu_offload_gb=0, device='auto', disable_custom_all_reduce=False, disable_log_requests=False, disable_log_stats=False, disable_logprobs_during_spec_decoding=None, disable_sliding_window=False, distributed_executor_backend=None, download_dir=None, dtype='auto', enable_chunked_prefill=None, enable_lora=False, enable_prefix_caching=False, enable_prompt_adapter=False, enforce_eager=False, engine_use_ray=False, fully_sharded_loras=False, gpu_memory_utilization=0.9, guided_decoding_backend='outlines', host=None, ignore_patterns=[], kv_cache_dtype='auto', load_format='auto', long_lora_scaling_factors=None, lora_dtype='auto', lora_extra_vocab_size=256, lora_modules=None, max_context_len_to_capture=None, max_cpu_loras=None, max_log_len=None, max_logprobs=20, max_lora_rank=16, max_loras=1, max_model_len=None, max_num_batched_tokens=None, max_num_seqs=256, max_parallel_loading_workers=None, max_prompt_adapter_token=0, max_prompt_adapters=1, max_seq_len_to_capture=8192, middleware=[], model='/workspace/atom/1/local_model/base_model/', model_loader_extra_config=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, num_gpu_blocks_override=None, num_lookahead_slots=0, num_speculative_tokens=None, otlp_traces_endpoint=None, pipeline_parallel_size=1, port=8000, preemption_mode=None, prompt_adapters=None, qlora_adapter_name_or_path=None, quantization=None, quantization_param_path=None, ray_workers_use_nsight=False, response_role='assistant', revision=None, root_path=None, rope_scaling=None, rope_theta=None, scheduler_delay_factor=0.0, seed=0, served_model_name=None, skip_tokenizer_init=False, spec_decoding_acceptance_method='rejection_sampler', speculative_disable_by_batch_size=None, speculative_draft_tensor_parallel_size=None, 
speculative_max_model_len=None, speculative_model=None, ssl_ca_certs=None, ssl_cert_reqs=0, ssl_certfile=None, ssl_keyfile=None, swap_space=4, tensor_parallel_size=4, tokenizer=None, tokenizer_mode='auto', tokenizer_pool_extra_config=None, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_revision=None, trust_remote_code=False, typical_acceptance_sampler_posterior_alpha=None, typical_acceptance_sampler_posterior_threshold=None, use_v2_block_manager=False, uvicorn_log_level='info', worker_use_ray=False)
INFO 08-15 07:47:02 config.py:715] Defaulting to use mp for distributed inference
WARNING 08-15 07:47:02 arg_utils.py:762] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.
INFO 08-15 07:47:02 config.py:806] Chunked prefill is enabled with max_num_batched_tokens=512.
INFO 08-15 07:47:02 llm_engine.py:176] Initializing an LLM engine (v0.5.3.post1) with config: model='/workspace/atom/1/local_model/base_model/', speculative_config=None, tokenizer='/workspace/atom/1/local_model/base_model/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/workspace/atom/1/local_model/base_model/, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 08-15 07:47:03 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
(VllmWorkerProcess pid=691) INFO 08-15 07:47:03 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=692) INFO 08-15 07:47:03 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=693) INFO 08-15 07:47:03 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 08-15 07:47:05 utils.py:784] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=691) INFO 08-15 07:47:05 utils.py:784] Found nccl from library libnccl.so.2
INFO 08-15 07:47:05 pynccl.py:63] vLLM is using nccl==2.20.5
(VllmWorkerProcess pid=692) INFO 08-15 07:47:05 utils.py:784] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=691) INFO 08-15 07:47:05 pynccl.py:63] vLLM is using nccl==2.20.5
(VllmWorkerProcess pid=692) INFO 08-15 07:47:05 pynccl.py:63] vLLM is using nccl==2.20.5
(VllmWorkerProcess pid=693) INFO 08-15 07:47:05 utils.py:784] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=693) INFO 08-15 07:47:05 pynccl.py:63] vLLM is using nccl==2.20.5
NCCL version 2.20.5+cuda11.0
(VllmWorkerProcess pid=693) WARNING 08-15 07:47:07 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=691) WARNING 08-15 07:47:07 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=692) WARNING 08-15 07:47:07 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
WARNING 08-15 07:47:07 custom_all_reduce.py:118] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
INFO 08-15 07:47:07 shm_broadcast.py:241] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7f97a5bad880>, local_subscribe_port=48358, local_sync_port=43845, remote_subscribe_port=None, remote_sync_port=None)
INFO 08-15 07:47:07 model_runner.py:680] Starting to load model /workspace/atom/1/local_model/base_model/...
(VllmWorkerProcess pid=693) INFO 08-15 07:47:07 model_runner.py:680] Starting to load model /workspace/atom/1/local_model/base_model/...
(VllmWorkerProcess pid=692) INFO 08-15 07:47:07 model_runner.py:680] Starting to load model /workspace/atom/1/local_model/base_model/...
(VllmWorkerProcess pid=691) INFO 08-15 07:47:07 model_runner.py:680] Starting to load model /workspace/atom/1/local_model/base_model/...
(VllmWorkerProcess pid=693) INFO 08-15 07:47:36 model_runner.py:692] Loading model weights took 34.0067 GB
(VllmWorkerProcess pid=691) INFO 08-15 07:47:36 model_runner.py:692] Loading model weights took 34.0067 GB
Loading safetensors checkpoint shards: 100% Completed | 37/37 [00:28<00:00,  1.24it/s]
Loading safetensors checkpoint shards: 100% Completed | 37/37 [00:28<00:00,  1.28it/s]

(VllmWorkerProcess pid=692) INFO 08-15 07:47:37 model_runner.py:692] Loading model weights took 34.0067 GB
INFO 08-15 07:47:37 model_runner.py:692] Loading model weights took 34.0067 GB
INFO 08-15 07:47:39 distributed_gpu_executor.py:56] # GPU blocks: 29158, # CPU blocks: 3276
(VllmWorkerProcess pid=691) INFO 08-15 07:47:43 model_runner.py:980] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(VllmWorkerProcess pid=691) INFO 08-15 07:47:43 model_runner.py:984] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 08-15 07:47:43 model_runner.py:980] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 08-15 07:47:43 model_runner.py:984] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=693) INFO 08-15 07:47:43 model_runner.py:980] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(VllmWorkerProcess pid=693) INFO 08-15 07:47:43 model_runner.py:984] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=692) INFO 08-15 07:47:43 model_runner.py:980] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(VllmWorkerProcess pid=692) INFO 08-15 07:47:43 model_runner.py:984] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.

newllm201:623:623 [0] misc/strongstream.cc:53 NCCL WARN NCCL cannot be captured in a graph if either it wasn't built with CUDA runtime >= 11.3 or if the installed CUDA driver < R465.

newllm201:693:693 [3] misc/strongstream.cc:53 NCCL WARN NCCL cannot be captured in a graph if either it wasn't built with CUDA runtime >= 11.3 or if the installed CUDA driver < R465.
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method initialize_cache: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details), Traceback (most recent call last):
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/worker.py", line 220, in initialize_cache
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     self._warm_up_model()
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/worker.py", line 236, in _warm_up_model
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     self.model_runner.capture_model(self.gpu_cache)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/model_runner.py", line 1173, in capture_model
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     graph_runner.capture(**capture_inputs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/model_runner.py", line 1411, in capture
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     output_hidden_or_intermediate_states = self.model(
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 336, in forward
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     hidden_states = self.model(input_ids, positions, kv_caches,
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 253, in forward
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     hidden_states = self.embed_tokens(input_ids)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 352, in forward
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     output = tensor_model_parallel_all_reduce(output_parallel)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/communication_op.py", line 11, in tensor_model_parallel_all_reduce
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     return get_tp_group().all_reduce(input_)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/parallel_state.py", line 291, in all_reduce
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     pynccl_comm.all_reduce(input_)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 118, in all_reduce
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 257, in ncclAllReduce
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 223, in NCCL_CHECK
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226]     raise RuntimeError(f"NCCL error: {error_str}")
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226] RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)
(VllmWorkerProcess pid=693) ERROR 08-15 07:47:44 multiproc_worker_utils.py:226] 
[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/entrypoints/openai/api_server.py", line 317, in <module>
[rank0]:     run_server(args)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/entrypoints/openai/api_server.py", line 231, in run_server
[rank0]:     if llm_engine is not None else AsyncLLMEngine.from_engine_args(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/engine/async_llm_engine.py", line 466, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/engine/async_llm_engine.py", line 380, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/engine/async_llm_engine.py", line 547, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/engine/llm_engine.py", line 265, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/engine/llm_engine.py", line 377, in _initialize_kv_caches
[rank0]:     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/executor/distributed_gpu_executor.py", line 62, in initialize_cache
[rank0]:     self._run_workers("initialize_cache",
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 178, in _run_workers
[rank0]:     driver_worker_output = driver_worker_method(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/worker.py", line 220, in initialize_cache
[rank0]:     self._warm_up_model()
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/worker.py", line 236, in _warm_up_model
[rank0]:     self.model_runner.capture_model(self.gpu_cache)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/model_runner.py", line 1173, in capture_model
[rank0]:     graph_runner.capture(**capture_inputs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/worker/model_runner.py", line 1411, in capture
[rank0]:     output_hidden_or_intermediate_states = self.model(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 336, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/models/qwen2.py", line 253, in forward
[rank0]:     hidden_states = self.embed_tokens(input_ids)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 352, in forward
[rank0]:     output = tensor_model_parallel_all_reduce(output_parallel)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/communication_op.py", line 11, in tensor_model_parallel_all_reduce
[rank0]:     return get_tp_group().all_reduce(input_)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/parallel_state.py", line 291, in all_reduce
[rank0]:     pynccl_comm.all_reduce(input_)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/device_communicators/pynccl.py", line 118, in all_reduce
[rank0]:     self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 257, in ncclAllReduce
[rank0]:     self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/vllm/distributed/device_communicators/pynccl_wrapper.py", line 223, in NCCL_CHECK
[rank0]:     raise RuntimeError(f"NCCL error: {error_str}")
[rank0]: RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)
ERROR 08-15 07:47:46 multiproc_worker_utils.py:120] Worker VllmWorkerProcess pid 691 died, exit code: -15
INFO 08-15 07:47:46 multiproc_worker_utils.py:123] Killing local vLLM worker processes
/usr/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Adding `--enforce-eager` makes it work.
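The workaround applies to the same launch command. The sketch below only assembles the command from this report with `--enforce-eager` appended; `build_server_cmd` and `launch_server` are hypothetical helpers, not vLLM APIs, and the model path is the reporter's. According to the log message above, the offline-API equivalent is `enforce_eager=True`.

```python
import os
import subprocess
import sys
from typing import List

# Model path taken from the report; replace with your own checkpoint directory.
MODEL_PATH = "/workspace/atom/1/local_model/base_model/"

def build_server_cmd(model_path: str, tp: int = 4, enforce_eager: bool = True) -> List[str]:
    """Assemble the api_server command line used in this report."""
    cmd = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        f"--model={model_path}", "-tp", str(tp),
    ]
    if enforce_eager:
        # Eager mode skips CUDA graph capture, so the NCCL all-reduce that
        # failed inside capture_model() is never captured into a graph.
        cmd.append("--enforce-eager")
    return cmd

def launch_server(model_path: str = MODEL_PATH) -> subprocess.Popen:
    """Start the server with NCCL warnings enabled, as in the original command."""
    env = dict(os.environ, NCCL_DEBUG="WARN")
    return subprocess.Popen(build_server_cmd(model_path), env=env)

if __name__ == "__main__":
    # Print the command instead of launching it.
    print(" ".join(build_server_cmd(MODEL_PATH)))
```

The trade-off, as noted in the reply below, is that running without CUDA graphs may reduce decoding performance.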

zhaotyer commented 3 months ago

@youkaichao Please take a look

youkaichao commented 3 months ago

The error message is clear:

NCCL WARN NCCL cannot be captured in a graph if either it wasn't built with CUDA runtime >= 11.3 or if the installed CUDA driver < R465

Upgrading your driver and using CUDA 12.1 should fix it.

Otherwise, add `--enforce-eager`, but it might hurt performance.
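The NCCL warning names two prerequisites for graph capture: NCCL built with CUDA runtime >= 11.3, and a CUDA driver >= R465. A minimal sketch applying them to the strings printed in this report (the `NCCL version 2.20.5+cuda11.0` log line and driver 535.104.05) shows the driver condition holds while the NCCL build condition does not, which would explain why a newer driver alone does not help. It only parses text; the helper names are hypothetical and nothing here queries NCCL or the driver.

```python
import re
from typing import Tuple

MIN_CUDA_RUNTIME = (11, 3)  # from the warning: "built with CUDA runtime >= 11.3"
MIN_DRIVER_MAJOR = 465      # from the warning: "installed CUDA driver < R465"

def nccl_build_cuda(banner: str) -> Tuple[int, int]:
    """Extract the CUDA version NCCL was built with from a line such as
    'NCCL version 2.20.5+cuda11.0'."""
    m = re.search(r"\+cuda(\d+)\.(\d+)", banner)
    if m is None:
        raise ValueError(f"no '+cudaX.Y' suffix in {banner!r}")
    return int(m.group(1)), int(m.group(2))

def driver_major(driver_version: str) -> int:
    """Major driver branch, e.g. '535.104.05' -> 535."""
    return int(driver_version.split(".")[0])

def nccl_graph_capture_ok(banner: str, driver_version: str) -> bool:
    """Both conditions from the NCCL warning must hold."""
    return (nccl_build_cuda(banner) >= MIN_CUDA_RUNTIME
            and driver_major(driver_version) >= MIN_DRIVER_MAJOR)

# Values from this report.
print(nccl_build_cuda("NCCL version 2.20.5+cuda11.0"))                      # (11, 0)
print(nccl_graph_capture_ok("NCCL version 2.20.5+cuda11.0", "535.104.05"))  # False
```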

zhaotyer commented 3 months ago

It probably has nothing to do with the driver version: it doesn't work on another server with driver 550 either. It is more likely related to the NCCL version; PyTorch works normally, but PyNcclCommunicator does not.

zhaotyer commented 3 months ago

The nvidia-nccl-cu12 and nvidia-nccl-cu11 packages are inconsistent with each other, so the CUDA 11.8 (cu118) build of vLLM currently has this problem.
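The `collect_env.py` output above lists `nvidia-nccl-cu11==2.20.5`. A small sketch for checking which of the two NCCL wheels mentioned here are installed in the current environment, using only the standard library; `installed_nccl_wheels` is a hypothetical helper. Note that the wheel version alone does not reveal which CUDA toolkit NCCL was compiled with; in this report that information comes from the `+cudaX.Y` suffix in NCCL's startup banner.

```python
from importlib import metadata
from typing import Dict, Optional

# Wheel names mentioned in this thread.
NCCL_WHEELS = ("nvidia-nccl-cu11", "nvidia-nccl-cu12")

def installed_nccl_wheels() -> Dict[str, Optional[str]]:
    """Return {wheel name: installed version, or None if not installed}."""
    versions: Dict[str, Optional[str]] = {}
    for name in NCCL_WHEELS:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

if __name__ == "__main__":
    for name, version in installed_nccl_wheels().items():
        print(f"{name}: {version or 'not installed'}")
```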

zhaotyer commented 3 months ago

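Recompiling NCCL against CUDA 11.8 solves this problem. The NCCL in nvidia-nccl-cu11 is built with CUDA 11.0 (hence `NCCL version 2.20.5+cuda11.0` in the log), below the CUDA runtime >= 11.3 that NCCL requires for graph capture, so it cannot use the stream feature.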
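After rebuilding NCCL, vLLM still has to load that build instead of the `libnccl.so.2` it found in the log (`Found nccl from library libnccl.so.2`). As far as we know, vLLM 0.5.x checks a `VLLM_NCCL_SO_PATH` environment variable before falling back to `libnccl.so.2`; verify this in `vllm/envs.py` for your version. The sketch below prepares such an environment. The library path is hypothetical (for example the `build/lib` output of `make -j src.build CUDA_HOME=/usr/local/cuda-11.8` in an NCCL checkout), and `env_with_custom_nccl` is not a vLLM API.

```python
import os
from pathlib import Path
from typing import Dict, Mapping, Optional

# Hypothetical location of an NCCL rebuilt against CUDA 11.8.
REBUILT_NCCL = "/opt/nccl/build/lib/libnccl.so.2"

def env_with_custom_nccl(so_path: str,
                         base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy an environment and point vLLM's pynccl at a specific libnccl file."""
    path = Path(so_path)
    if not path.is_file():
        raise FileNotFoundError(f"NCCL library not found: {path}")
    env = dict(os.environ if base is None else base)
    # Assumption: vLLM reads VLLM_NCCL_SO_PATH before trying "libnccl.so.2".
    env["VLLM_NCCL_SO_PATH"] = str(path.resolve())
    return env
```

The returned dictionary can be passed as `env=` to `subprocess.Popen` when launching the API server; the startup log should then report the library taken from the environment variable instead of `libnccl.so.2`.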

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!