vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Possible data race when running Llama 405b fp8 #6767

Closed: tlrmchlsmth closed this issue 1 month ago

tlrmchlsmth commented 1 month ago

Your current environment

Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.82
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 555.42.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             128
On-line CPU(s) list:                0-127
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                         6
Model:                              143
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          2
Stepping:                           8
CPU max MHz:                        4100.0000
CPU min MHz:                        800.0000
BogoMIPS:                           5600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hfi vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          3 MiB (64 instances)
L1i cache:                          2 MiB (64 instances)
L2 cache:                           128 MiB (64 instances)
L3 cache:                           120 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-31,64-95
NUMA node1 CPU(s):                  32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] flashinfer==0.1.1+cu121torch2.3
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] onnx==1.14.1
[pip3] onnxruntime==1.18.1
[pip3] sentence-transformers==3.0.1
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[pip3] transformers==4.43.2
[pip3] triton==2.3.1
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     0-31,64-95      0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    32-63,96-127    1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    32-63,96-127    1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    32-63,96-127    1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     32-63,96-127    1               N/A
NIC0    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE    SYS     SYS     SYS     SYS
NIC1    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE    SYS     SYS     SYS     SYS
NIC2    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE    SYS     SYS     SYS     SYS
NIC3    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X      SYS     SYS     SYS     SYS
NIC4    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE
NIC5    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      NODE    NODE
NIC6    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    NODE     X      NODE
NIC7    SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7

🐛 Describe the bug

I'm debugging some hard-to-reproduce illegal memory accesses that occur while running Llama 3.1 405B in FP8.

I am running the following command.

TORCH_CUDA_SANITIZER=1 CUDA_LAUNCH_BLOCKING=1  \
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 \
    --disable-log-requests --port 8192 -tp 8 --enable-chunked-prefill --max-num-batched-tokens 2048 \
    --max-num-seqs 128 --tokenizer-pool-size 2 --disable-custom-all-reduce --max-model-len 16384
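
For anyone scripting the repro, here is a minimal sketch that assembles the same environment and arguments from Python so the two debug variables can be toggled independently. The helper name build_repro_command is hypothetical; the flags are copied from the command above, and nothing is launched until you pass the result to subprocess.run yourself.

```python
import os


def build_repro_command(model="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8", port=8192):
    """Return (env, argv) mirroring the repro command above (hypothetical helper)."""
    env = dict(os.environ)
    # Turn on PyTorch's CUDA Sanitizer (CSAN), which reports accesses to the same
    # tensor memory from different streams that are not ordered by synchronization.
    env["TORCH_CUDA_SANITIZER"] = "1"
    # Make kernel launches synchronous so a failing kernel is reported at its own launch.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    argv = [
        "python3", "-m", "vllm.entrypoints.openai.api_server",
        "--model", model,
        "--disable-log-requests",
        "--port", str(port),
        "-tp", "8",  # tensor parallelism across all 8 H100s
        "--enable-chunked-prefill",
        "--max-num-batched-tokens", "2048",
        "--max-num-seqs", "128",
        "--tokenizer-pool-size", "2",
        "--disable-custom-all-reduce",
        "--max-model-len", "16384",
    ]
    return env, argv


env, argv = build_repro_command()
print(" ".join(argv))
```

To actually start the server, call subprocess.run(argv, env=env). Keeping the environment separate from argv makes it easy to drop CUDA_LAUNCH_BLOCKING and compare the two failure modes.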

With CUDA_LAUNCH_BLOCKING=1 and TORCH_CUDA_SANITIZER=1 set, I get warnings about a possible data race, as shown in the following log:

(VllmWorkerProcess pid=456320) ============================
(VllmWorkerProcess pid=456320) CSAN detected a possible data race on tensor with data pointer 134551410147328
(VllmWorkerProcess pid=456320) Access by stream 106208604037440 during kernel:
(VllmWorkerProcess pid=456320) aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
(VllmWorkerProcess pid=456320) writing to argument(s) self, and to the output
(VllmWorkerProcess pid=456320) With stack trace:
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
(VllmWorkerProcess pid=456320)     return _run_code(code, main_globals, None,
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
(VllmWorkerProcess pid=456320)     exec(code, run_globals)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/entrypoints/openai/api_server.py", line 317, in <module>
(VllmWorkerProcess pid=456320)     run_server(args)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/entrypoints/openai/api_server.py", line 231, in run_server
(VllmWorkerProcess pid=456320)     if llm_engine is not None else AsyncLLMEngine.from_engine_args(
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/engine/async_llm_engine.py", line 466, in from_engine_args
(VllmWorkerProcess pid=456320)     engine = cls(
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/engine/async_llm_engine.py", line 380, in __init__
(VllmWorkerProcess pid=456320)     self.engine = self._init_engine(*args, **kwargs)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/engine/async_llm_engine.py", line 547, in _init_engine
(VllmWorkerProcess pid=456320)     return engine_class(*args, **kwargs)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/engine/llm_engine.py", line 251, in __init__
(VllmWorkerProcess pid=456320)     self.model_executor = executor_class(
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/executor/multiproc_gpu_executor.py", line 201, in __init__
(VllmWorkerProcess pid=456320)     super().__init__(*args, **kwargs)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/executor/distributed_gpu_executor.py", line 25, in __init__
(VllmWorkerProcess pid=456320)     super().__init__(*args, **kwargs)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/executor/executor_base.py", line 47, in __init__
(VllmWorkerProcess pid=456320)     self._init_executor()
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/executor/multiproc_gpu_executor.py", line 89, in _init_executor
(VllmWorkerProcess pid=456320)     worker = ProcessWorkerWrapper(
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/executor/multiproc_worker_utils.py", line 162, in __init__
(VllmWorkerProcess pid=456320)     self.process.start()
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
(VllmWorkerProcess pid=456320)     self._popen = self._Popen(self)
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
(VllmWorkerProcess pid=456320)     return Popen(process_obj)
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
(VllmWorkerProcess pid=456320)     self._launch(process_obj)
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch
(VllmWorkerProcess pid=456320)     code = process_obj._bootstrap(parent_sentinel=child_r)
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
(VllmWorkerProcess pid=456320)     self.run()
(VllmWorkerProcess pid=456320)   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
(VllmWorkerProcess pid=456320)     self._target(*self._args, **self._kwargs)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=456320)     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/worker/worker.py", line 220, in initialize_cache
(VllmWorkerProcess pid=456320)     self._warm_up_model()
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/worker/worker.py", line 236, in _warm_up_model
(VllmWorkerProcess pid=456320)     self.model_runner.capture_model(self.gpu_cache)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VllmWorkerProcess pid=456320)     return func(*args, **kwargs)
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/vllm/worker/model_runner.py", line 1111, in capture_model
(VllmWorkerProcess pid=456320)     block_tables=block_tables[:batch_size],
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/venv/lib/python3.10/site-packages/torch/cuda/_sanitizer.py", line 573, in __torch_dispatch__
(VllmWorkerProcess pid=456320)     errors = self.event_handler._handle_kernel_launch(
(VllmWorkerProcess pid=456320)   File "/home/tms/nm-vllm/venv/lib/python3.10/site-packages/torch/cuda/_sanitizer.py", line 374, in _handle_kernel_launch
(VllmWorkerProcess pid=456320)     stack_trace = traceback.StackSummary.extract(

This points to the following code: https://github.com/vllm-project/vllm/blob/5689e256baf0c45148a01ad147abf11ad82c9690/vllm/worker/model_runner.py#L1137-L1152

I'm not sure why slicing block_tables would be an issue here.
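
To help read the report, here is a toy model of a stream-ordered (happens-before) race checker keyed by data pointer. It is a simplified sketch, not PyTorch's CSAN implementation; the class names, stream ids, and data pointer are hypothetical. What it illustrates: the log records aten::slice as writing to both self and the output, and a view shares its base tensor's data pointer, so a slice issued on one stream is checked against earlier accesses to block_tables on other streams even though no data is copied.

```python
from dataclasses import dataclass, field


@dataclass
class Access:
    stream: int
    seq: int  # position of this access on its own stream
    op: str


@dataclass
class ToyRaceChecker:
    """Toy happens-before race checker keyed by data pointer (not PyTorch's CSAN)."""
    clocks: dict = field(default_factory=dict)      # stream -> {stream: latest seq ordered before it}
    last_write: dict = field(default_factory=dict)  # data_ptr -> Access
    reads: dict = field(default_factory=dict)       # data_ptr -> [Access] since the last write

    def _clock(self, stream):
        return self.clocks.setdefault(stream, {stream: 0})

    def _ordered_before(self, prev, stream):
        # prev happens-before new work on `stream` if it was issued on the same stream,
        # or if `stream` has since waited on prev.stream at or after prev.seq.
        return prev.stream == stream or self._clock(stream).get(prev.stream, 0) >= prev.seq

    def access(self, data_ptr, stream, op, write):
        clock = self._clock(stream)
        clock[stream] += 1
        cur = Access(stream, clock[stream], op)
        conflicts = []
        prev_write = self.last_write.get(data_ptr)
        if prev_write is not None and not self._ordered_before(prev_write, stream):
            conflicts.append(prev_write)
        if write:
            # A write must also be ordered after every earlier read of the same memory.
            conflicts += [r for r in self.reads.get(data_ptr, [])
                          if not self._ordered_before(r, stream)]
            self.last_write[data_ptr] = cur
            self.reads[data_ptr] = []
        else:
            self.reads.setdefault(data_ptr, []).append(cur)
        return [f"possible data race on {data_ptr:#x}: {op!r} on stream {stream} "
                f"vs {c.op!r} on stream {c.stream}" for c in conflicts]

    def wait_stream(self, waiter, other):
        # Analogous to waiter.wait_stream(other): work already issued on `other`
        # is ordered before anything issued on `waiter` afterwards.
        mine = self._clock(waiter)
        for s, seq in self._clock(other).items():
            mine[s] = max(mine.get(s, 0), seq)


BLOCK_TABLES = 0x1000  # hypothetical data pointer of block_tables
DEFAULT, SIDE = 1, 2   # hypothetical stream ids

checker = ToyRaceChecker()
checker.access(BLOCK_TABLES, DEFAULT, "fill block_tables", write=True)
# The slice shares block_tables' data pointer, so it is checked against the fill.
print(checker.access(BLOCK_TABLES, SIDE, "aten::slice", write=True))

synced = ToyRaceChecker()
synced.access(BLOCK_TABLES, DEFAULT, "fill block_tables", write=True)
synced.wait_stream(SIDE, DEFAULT)  # explicit cross-stream synchronization
print(synced.access(BLOCK_TABLES, SIDE, "aten::slice", write=True))  # []
```

In this model, whether a report is a real bug or a false positive depends entirely on whether the checker can see the synchronization between the two streams.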

If you don't set CUDA_LAUNCH_BLOCKING=1, you may see an error like the following:

Error: Failed to initialize the TMA descriptor 700
[rank2]:[E ProcessGroupNCCL.cpp:1414] [PG 2 Rank 2] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered

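
As for why the failure looks different without CUDA_LAUNCH_BLOCKING=1: CUDA kernel launches are asynchronous, so a fault is typically reported at a later API call (here, apparently, in the NCCL watchdog thread) rather than at the launch that caused it. The plain-Python toy below sketches that behavior; the ToyStream class and the kernel names are hypothetical, and this is not the CUDA runtime.

```python
class ToyStream:
    """Toy model of an asynchronous launch queue; not the CUDA runtime."""

    def __init__(self, launch_blocking=False):
        self.launch_blocking = launch_blocking
        self.pending = []  # (name, fn) pairs queued but not yet executed

    def launch(self, name, fn):
        """Queue `fn`; with launch_blocking it runs immediately, like CUDA_LAUNCH_BLOCKING=1."""
        self.pending.append((name, fn))
        if self.launch_blocking:
            self.synchronize(caller=f"launch({name})")

    def synchronize(self, caller="synchronize"):
        """Run queued work; a failure is reported to whoever synchronizes, not the launcher."""
        while self.pending:
            name, fn = self.pending.pop(0)
            try:
                fn()
            except Exception as exc:
                self.pending.clear()  # later work is abandoned once the stream has faulted
                raise RuntimeError(
                    f"error from kernel {name!r}, reported at {caller!r}: {exc}"
                ) from exc


def faulting_kernel():
    raise MemoryError("illegal memory access")


stream = ToyStream()
stream.launch("slice_kernel", faulting_kernel)  # returns immediately, no error yet
stream.launch("all_reduce", lambda: None)
try:
    stream.synchronize(caller="watchdog")
except RuntimeError as err:
    print(err)  # error from kernel 'slice_kernel', reported at 'watchdog': illegal memory access
```

With ToyStream(launch_blocking=True), the same error is raised from the launch of slice_kernel itself, which is the trade-off CUDA_LAUNCH_BLOCKING=1 makes: accurate attribution at the cost of serialized, differently timed execution.
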
tlrmchlsmth commented 1 month ago

Might be related to https://github.com/vllm-project/vllm/issues/4108

comaniac commented 1 month ago

Actually, the previous MoE illegal memory access error was also caused by slicing input tensors. Not sure what's happening under the hood...
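
One generic way slicing can matter for custom CUDA kernels (offered as background, not as the confirmed cause of either bug) is memory layout: a kernel that assumes a contiguous input computes addresses as offset + i, which is only correct when the view is contiguous. The plain-Python sketch below models a tensor as (shape, strides, offset); the helper names and the 4 x 8 table are hypothetical. It shows that a leading-dimension slice such as block_tables[:batch_size] stays contiguous, while a slice along the last dimension does not.

```python
from itertools import product


def is_contiguous(shape, strides):
    """True if (shape, strides) describes a dense row-major layout (size-1 dims ignored)."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


def slice_view(shape, strides, offset, dim, start, stop):
    """Model x[..., start:stop, ...] along `dim`: same strides, shifted offset, smaller shape."""
    new_shape = list(shape)
    new_shape[dim] = stop - start
    return tuple(new_shape), tuple(strides), offset + start * strides[dim]


def element_offsets(shape, strides, offset):
    """Storage offsets a stride-aware kernel would read, in row-major index order."""
    return [offset + sum(i * s for i, s in zip(idx, strides))
            for idx in product(*(range(n) for n in shape))]


base_shape, base_strides = (4, 8), (8, 1)  # hypothetical table: 4 sequences x 8 blocks

rows = slice_view(base_shape, base_strides, 0, dim=0, start=0, stop=2)  # like table[:2]
cols = slice_view(base_shape, base_strides, 0, dim=1, start=4, stop=8)  # like table[:, 4:]

print(is_contiguous(*rows[:2]))    # True
print(is_contiguous(*cols[:2]))    # False
print(element_offsets(*cols)[:5])  # [4, 5, 6, 7, 12]; a contiguity-assuming kernel reads 4, 5, 6, 7, 8
```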

tlrmchlsmth commented 1 month ago

This looks like a red herring, but with #6788 we can at least use TORCH_CUDA_SANITIZER=1 to investigate the underlying issue.