vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Mixtral 8-way TP with --enable-lora crashes with CUDA illegal memory access error #6902

Open tjohnson31415 opened 3 months ago

tjohnson31415 commented 3 months ago

Your current environment

Output of `python collect_env.py` ```text Collecting environment information... PyTorch version: 2.3.1+cu121 Is debug build: False CUDA used to build PyTorch: 12.1 ROCM used to build PyTorch: N/A OS: Red Hat Enterprise Linux 9.4 (Plow) (x86_64) GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3) Clang version: Could not collect CMake version: version 3.30.0 Libc version: glibc-2.34 Python version: 3.11.7 (main, May 16 2024, 00:00:00) [GCC 11.4.1 20231218 (Red Hat 11.4.1-3)] (64-bit runtime) Python platform: Linux-5.14.0-284.52.1.el9_2.x86_64-x86_64-with-glibc2.34 Is CUDA available: True CUDA runtime version: Could not collect CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB GPU 1: NVIDIA A100-SXM4-80GB GPU 2: NVIDIA A100-SXM4-80GB GPU 3: NVIDIA A100-SXM4-80GB GPU 4: NVIDIA A100-SXM4-80GB GPU 5: NVIDIA A100-SXM4-80GB GPU 6: NVIDIA A100-SXM4-80GB GPU 7: NVIDIA A100-SXM4-80GB Nvidia driver version: 550.54.15 cuDNN version: Could not collect HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Address sizes: 46 bits physical, 57 bits virtual Byte Order: Little Endian CPU(s): 80 On-line CPU(s) list: 0-79 Vendor ID: GenuineIntel Model name: Intel Xeon Processor (Icelake) CPU family: 6 Model: 134 Thread(s) per core: 2 Core(s) per socket: 20 Socket(s): 2 Stepping: 0 BogoMIPS: 5600.03 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb 
avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm md_clear arch_capabilities Virtualization: VT-x Hypervisor vendor: KVM Virtualization type: full L1d cache: 2.5 MiB (80 instances) L1i cache: 2.5 MiB (80 instances) L2 cache: 160 MiB (40 instances) L3 cache: 32 MiB (2 instances) NUMA node(s): 2 NUMA node0 CPU(s): 0-39 NUMA node1 CPU(s): 40-79 Vulnerability Gather data sampling: Not affected Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown Vulnerability Retbleed: Not affected Vulnerability Spec rstack overflow: Not affected Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS Not affected Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Not affected Versions of relevant libraries: [pip3] flashinfer==0.0.9+cu121torch2.3 [pip3] numpy==1.26.4 [pip3] nvidia-nccl-cu12==2.20.5 [pip3] torch==2.3.1 [pip3] torchvision==0.18.1 [pip3] transformers==4.42.4 [pip3] triton==2.3.1 [conda] Could not collect ROCM Version: Could not collect Neuron SDK Version: N/A vLLM Version: 0.5.3 vLLM Build Flags: CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled GPU Topology: GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 CPU Affinity NUMA Affinity GPU NUMA ID GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12 SYS SYS PIX PIX PIX 0-39 0 N/A GPU1 NV12 X NV12 NV12 NV12 NV12 NV12 NV12 SYS SYS PIX PIX PIX 0-39 0 N/A GPU2 NV12 NV12 X NV12 NV12 NV12 NV12 NV12 SYS SYS SYS SYS SYS 0-39 0 N/A GPU3 NV12 NV12 
NV12 X NV12 NV12 NV12 NV12 SYS SYS SYS SYS SYS 0-39 0 N/A GPU4 NV12 NV12 NV12 NV12 X NV12 NV12 NV12 PIX PIX SYS SYS SYS 40-79 1 N/A GPU5 NV12 NV12 NV12 NV12 NV12 X NV12 NV12 PIX PIX SYS SYS SYS 40-79 1 N/A GPU6 NV12 NV12 NV12 NV12 NV12 NV12 X NV12 SYS SYS SYS SYS SYS 40-79 1 N/A GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12 X SYS SYS SYS SYS SYS 40-79 1 N/A NIC0 SYS SYS SYS SYS PIX PIX SYS SYS X PIX SYS SYS SYS NIC1 SYS SYS SYS SYS PIX PIX SYS SYS PIX X SYS SYS SYS NIC2 PIX PIX SYS SYS SYS SYS SYS SYS SYS SYS X PIX PIX NIC3 PIX PIX SYS SYS SYS SYS SYS SYS SYS SYS PIX X PIX NIC4 PIX PIX SYS SYS SYS SYS SYS SYS SYS SYS PIX PIX X Legend: X = Self SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge) PIX = Connection traversing at most a single PCIe bridge NV# = Connection traversing a bonded set of # NVLinks NIC Legend: NIC0: mlx5_0 NIC1: mlx5_1 NIC2: mlx5_2 NIC3: mlx5_3 NIC4: mlx5_4 ```

🐛 Describe the bug

Running mistralai/Mixtral-8x7B-Instruct-v0.1 with 8-way TP and --enable-lora results in a crash during boot up when executing determine_num_available_blocks.

The error is:

[rank0]:   File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/_tensor_str.py", line 331, in _tensor_str
[rank0]:     self = self.float()
[rank0]:            ^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered

Example command that results in the failure:

vllm serve mistralai/Mixtral-8x7B-Instruct-v0.1 --tensor-parallel-size 8 --enable-lora

Logs with full stack trace (from one rank) ```text INFO 07-29 15:57:19 api_server.py:219] vLLM API server version 0.5.3 INFO 07-29 15:57:19 api_server.py:220] args: Namespace(model_tag='mistralai/Mixtral-8x7B-Instruct-v0.1', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], model='mistralai/Mixtral-8x7B-Instruct-v0.1', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=8, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, enable_lora=True, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', scheduler_delay_factor=0.0, enable_chunked_prefill=None, 
speculative_model=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, engine_use_ray=False, disable_log_requests=False, max_log_len=None, dispatch_function=) INFO 07-29 15:57:19 config.py:724] Defaulting to use mp for distributed inference INFO 07-29 15:57:19 llm_engine.py:176] Initializing an LLM engine (v0.5.3) with config: model='mistralai/Mixtral-8x7B-Instruct-v0.1', speculative_config=None, tokenizer='mistralai/Mixtral-8x7B-Instruct-v0.1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=mistralai/Mixtral-8x7B-Instruct-v0.1, use_v2_block_manager=False, enable_prefix_caching=False) INFO 07-29 15:57:25 utils.py:784] Found nccl from library libnccl.so.2 INFO 07-29 15:57:25 pynccl.py:63] vLLM is using nccl==2.20.5 INFO 07-29 15:57:28 custom_all_reduce_utils.py:232] reading GPU P2P access cache from /home/vllm/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json INFO 07-29 15:57:28 shm_broadcast.py:241] vLLM message 
queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer=, local_subscribe_port=42929, local_sync_port=39285, remote_subscribe_port=None, remote_sync_port=None) INFO 07-29 15:57:28 model_runner.py:680] Starting to load model mistralai/Mixtral-8x7B-Instruct-v0.1... INFO 07-29 15:57:28 weight_utils.py:223] Using model weights format ['*.safetensors', '*.bin'] Loading safetensors checkpoint shards: 0% Completed | 0/19 [00:00 [rank0]: sys.exit(main()) [rank0]: ^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/scripts.py", line 148, in main [rank0]: args.dispatch_function(args) [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/scripts.py", line 28, in serve [rank0]: run_server(args) [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/entrypoints/openai/api_server.py", line 231, in run_server [rank0]: if llm_engine is not None else AsyncLLMEngine.from_engine_args( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 466, in from_engine_args [rank0]: engine = cls( [rank0]: ^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 380, in __init__ [rank0]: self.engine = self._init_engine(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 547, in _init_engine [rank0]: return engine_class(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/llm_engine.py", line 265, in __init__ [rank0]: self._initialize_kv_caches() [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/engine/llm_engine.py", line 364, in _initialize_kv_caches [rank0]: self.model_executor.determine_num_available_blocks()) [rank0]: 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/executor/distributed_gpu_executor.py", line 38, in determine_num_available_blocks [rank0]: num_blocks = self._run_workers("determine_num_available_blocks", ) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/executor/multiproc_gpu_executor.py", line 178, in _run_workers [rank0]: driver_worker_output = driver_worker_method(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context [rank0]: return func(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/worker/worker.py", line 179, in determine_num_available_blocks [rank0]: self.model_runner.profile_run() [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context [rank0]: return func(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/worker/model_runner.py", line 896, in profile_run [rank0]: self.execute_model(model_input, kv_caches, intermediate_tensors) [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context [rank0]: return func(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/worker/model_runner.py", line 1314, in execute_model [rank0]: hidden_or_intermediate_states = model_executable( [rank0]: ^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File 
"/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/model_executor/models/mixtral.py", line 374, in forward [rank0]: hidden_states = self.model(input_ids, positions, kv_caches, [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/model_executor/models/mixtral.py", line 296, in forward [rank0]: hidden_states, residual = layer(positions, hidden_states, [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/model_executor/models/mixtral.py", line 233, in forward [rank0]: hidden_states = self.self_attn( [rank0]: ^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File 
"/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/model_executor/models/mixtral.py", line 180, in forward [rank0]: output, _ = self.o_proj(attn_output) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/lora/layers.py", line 1001, in forward [rank0]: output_ = tensor_model_parallel_all_reduce(output_parallel) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/distributed/communication_op.py", line 11, in tensor_model_parallel_all_reduce [rank0]: return get_tp_group().all_reduce(input_) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/vllm/distributed/parallel_state.py", line 293, in all_reduce [rank0]: torch.distributed.all_reduce(input_, group=self.device_group) [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/distributed/c10d_logger.py", line 77, in wrapper [rank0]: msg_dict = _get_msg_dict(func.__name__, *args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/distributed/c10d_logger.py", line 50, in _get_msg_dict [rank0]: "args": f"{args}, {kwargs}", [rank0]: ^^^^^^^^^^^^^^^^^^^ [rank0]: File 
"/workspace/my-vllm/lib64/python3.11/site-packages/torch/_tensor.py", line 464, in __repr__ [rank0]: return torch._tensor_str._str(self, tensor_contents=tensor_contents) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/_tensor_str.py", line 697, in _str [rank0]: return _str_intern(self, tensor_contents=tensor_contents) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/_tensor_str.py", line 617, in _str_intern [rank0]: tensor_str = _tensor_str(self, indent) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/workspace/my-vllm/lib64/python3.11/site-packages/torch/_tensor_str.py", line 331, in _tensor_str [rank0]: self = self.float() [rank0]: ^^^^^^^^^^^^ [rank0]: RuntimeError: CUDA error: an illegal memory access was encountered [rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. INFO 07-29 15:57:37 multiproc_worker_utils.py:123] Killing local vLLM worker processes terminate called after throwing an instance of 'c10::Error' what(): CUDA error: an illegal memory access was encountered Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. 
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fa927bf2897 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libc10.so) frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fa927ba2b25 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libc10.so) frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fa927cca718 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libc10_cuda.so) frame #3: + 0x1045ef6 (0x7fa928d68ef6 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libtorch_cuda.so) frame #4: + 0x5a7380 (0x7fa973b33380 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libtorch_python.so) frame #5: + 0x6a36f (0x7fa927bd736f in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libc10.so) frame #6: c10::TensorImpl::~TensorImpl() + 0x21b (0x7fa927bd01cb in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libc10.so) frame #7: c10::TensorImpl::~TensorImpl() + 0x9 (0x7fa927bd0379 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libc10.so) frame #8: + 0x858328 (0x7fa973de4328 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libtorch_python.so) frame #9: THPVariable_subclass_dealloc(_object*) + 0x2f6 (0x7fa973de46a6 in /workspace/my-vllm/lib64/python3.11/site-packages/torch/lib/libtorch_python.so) frame #25: + 0x29590 (0x7fa975e08590 in /usr/lib64/libc.so.6) frame #26: __libc_start_main + 0x80 (0x7fa975e08640 in /usr/lib64/libc.so.6) frame #27: _start + 0x25 (0x5571031a7095 in /workspace/my-vllm/bin/python3.11) /usr/lib64/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown warnings.warn('resource_tracker: There appear to 
be %d ' ```

The stack trace shows the error surfacing in the logging code in c10d_logger.py after the call to torch.distributed.all_reduce, but I think the GPU memory is already corrupted by that point and the calls in the stack trace are just the next place the data is accessed. In my investigation, I tracked the source of the memory corruption to the first call to the Punica kernels at https://github.com/vllm-project/vllm/blob/v0.5.3/vllm/lora/punica.py#L136. After that call, any attempt to access the data of the resulting tensors raises the illegal memory access error. I determined the sizes of the tensors going into the call and was able to write a simple reproducer script (works on a single GPU):

import torch
from vllm import _custom_ops as ops

# Tensor sizes taken from the first Punica call made during startup.
# crashes with 8192 and 16384 too
seq_len = 32768
h1 = 512
h2 = 16
with torch.device('cuda'):
    buffer = torch.zeros([seq_len, h2], dtype=torch.bfloat16)
    x = torch.zeros([seq_len, h1], dtype=torch.float32)
    wa_t_all = torch.zeros([1, 1, h2, h1], dtype=torch.bfloat16)
    indices = torch.zeros([seq_len], dtype=torch.long)

ops.dispatch_bgmv(buffer, x, wa_t_all, indices, 0, 1.0)

# Reading the output is what surfaces the failure:
# crashes with CUDA error: an illegal memory access was encountered
buffer.any()
print("SUCCESS")

Very similar issue reported for Mistral 7B: https://github.com/vllm-project/vllm/issues/6725
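Because CUDA kernels run asynchronously, an illegal access is typically reported at the next synchronizing call rather than at the launch that caused it, which is consistent with the stack trace above pointing at tensor printing. A minimal sketch of one way to pin the failure to a specific step is to force a synchronization right after it. The helper name `run_and_sync` and its parameters are hypothetical and not part of vLLM or PyTorch; the `sync` callable is injected so the helper itself has no GPU dependency.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def run_and_sync(step: Callable[[], T], sync: Callable[[], None], label: str) -> T:
    """Run `step`, then call `sync` so that a pending asynchronous device
    error is raised here instead of at some later, unrelated access.

    `run_and_sync` is a hypothetical helper for illustration only.
    """
    result = step()
    try:
        sync()  # e.g. torch.cuda.synchronize when running on a GPU
    except RuntimeError as err:
        # Re-raise with the label so the log names the step that was just run.
        raise RuntimeError(f"device error surfaced after {label!r}") from err
    return result
```

In the reproducer this could be used as `run_and_sync(lambda: ops.dispatch_bgmv(buffer, x, wa_t_all, indices, 0, 1.0), torch.cuda.synchronize, "dispatch_bgmv")`. Setting the environment variable `CUDA_LAUNCH_BLOCKING=1` before starting the process is another common way to make kernel launches synchronous while debugging.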

njhill commented 3 months ago

@Yard1 any idea where to start on this one? :)

jeejeelee commented 3 months ago

@tjohnson31415 I can reproduce your error using the above script. Running it under compute-sanitizer shows that bgmv_shrink_kernel has an out-of-bounds read:

========= Invalid __global__ read of size 16 bytes
=========     at 0x220 in void bgmv_shrink_kernel<(int)512, (int)16, (unsigned long)8, (unsigned long)32, (unsigned long)16, (int)32, (int)4, (int)4, float, __nv_bfloat16, __nv_bfloat16>(T10 *, const T9 *, const T11 *, const long *, long, long, long, long, float)
=========     by thread (3,2,0) in block (9,32767,0)
=========     Address 0x7fc06a000060 is out of bounds
=========     and is 97 bytes after the nearest allocation at 0x7fc066000000 of size 67108864 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
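The numbers in this report can be cross-checked against the reproducer with a little arithmetic. The sketch below first confirms that the 67108864-byte allocation has exactly the size of the float32 tensor `x`, and that the faulting address lies just past its end. The second half is an assumption, not something stated in this thread: it guesses that the kernel template arguments `8` and `32` are a per-thread vector width and the x-dimension of the thread block, and that each thread loads `vec_size` contiguous elements starting at `(threadIdx.y * tx + threadIdx.x) * vec_size` within its row.

```python
# Values copied from the reproducer script and the compute-sanitizer report.
seq_len, h1 = 32768, 512
float32_bytes = 4
alloc_base = 0x7FC066000000
alloc_size = 67108864
fault_addr = 0x7FC06A000060

# The nearest allocation has exactly the size of x (seq_len x h1 float32).
x_bytes = seq_len * h1 * float32_bytes

# Distance of the faulting address past the end of that allocation.
bytes_past_end = fault_addr - (alloc_base + alloc_size)

# ASSUMPTION (not confirmed in the thread): template arguments 8 and 32 are
# the per-thread vector width and the block's x-dimension, and a thread reads
# elements starting at (thread_y * tx_dim + thread_x) * vec_size in its row.
vec_size, tx_dim = 8, 32
thread_x, thread_y = 3, 2  # "thread (3,2,0)" in the report
elem_in_row = (thread_y * tx_dim + thread_x) * vec_size

# For the last row (block y index 32767), reading past column h1 leaves x.
overrun_bytes = (elem_in_row - h1) * float32_bytes

print(x_bytes == alloc_size, bytes_past_end, elem_in_row, overrun_bytes)
```

Under that assumed indexing, thread (3,2,0) would start reading at element 536 of a 512-element row, and for the last row this lands at the same distance past the end of `x` as the reported fault address. That would be consistent with the thread block covering more elements per row than `h1 = 512` provides, but the kernel source should be checked before drawing that conclusion.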

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!