HabanaAI / vllm-fork

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Habana_NEXT fails on Lazy Mode + Tensor Parallel - [Quick Fix Provided] #173

Open xuechendi opened 1 month ago

xuechendi commented 1 month ago

πŸ› Describe the bug

PR #115 (commit 969bd835, "Do reshape_and_cache in bulk for prompt (#115)") results in an HPUGraph forward replayV3 failure in tensor-parallel mode.

=== Details ===

from vllm import LLM, SamplingParams

def test_multiple_sampling_params():

    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
              tensor_parallel_size=2, 
              block_size=128, 
              max_num_seqs=128
             )

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = [
        SamplingParams(temperature=1.0, top_p=1.0, use_beam_search=False, max_tokens=128)
        for _ in prompts
    ]

    # Multiple SamplingParams should be matched with each prompt
    output = llm.generate(prompts, sampling_params=sampling_params)

    print("**Test Completed, output as below**")
    print(output)

if __name__ == "__main__":
    test_multiple_sampling_params()
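
For reference, error log 1 below was captured by running this repro under pytest as tests/entrypoints/test_llm_generate_TP.py, while error log 2 comes from running a copy of it directly as a script. An equivalent programmatic pytest launch from a separate launcher file would look like this (illustrative only, nothing vLLM-specific assumed):

# Illustrative launcher; assumes the repro above is saved as
# tests/entrypoints/test_llm_generate_TP.py (the path seen in error log 1).
import sys

import pytest

if __name__ == "__main__":
    # -s streams stdout so the vLLM/HPU log lines below remain visible.
    sys.exit(pytest.main([
        "tests/entrypoints/test_llm_generate_TP.py::test_multiple_sampling_params",
        "-s",
    ]))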

* error log 1 - SIGSEGV in HPUGraph replayV3 during profiling/warm-up

tests/entrypoints/test_llm_generate_TP.py::test_multiple_sampling_params
2024-08-12 19:18:49,602 INFO worker.py:1749 -- Started a local Ray instance.
INFO 08-12 19:18:50 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='meta-llama/Meta-Llama-3-8B', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, weights_load_device=hpu, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=hpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=meta-llama/Meta-Llama-3-8B)
INFO 08-12 19:18:55 profiler.py:59] Profiler enabled for: vllm-instance-3facac1d86db4625a7b0f72d4b6fe35e
WARNING 08-12 19:18:55 utils.py:455] Pin memory is not supported on HPU.
INFO 08-12 19:18:55 selector.py:52] Using HabanaAttention backend.
..
INFO 08-12 19:18:55 habana_model_runner.py:463] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024]
INFO 08-12 19:18:55 habana_model_runner.py:464] Decode bucket config (min, step, max_warmup) bs:[32, 32, 128], seq:[128, 128, 2048]
============================= HABANA PT BRIDGE CONFIGURATION ===========================
PT_HPU_LAZY_MODE = 1
PT_RECIPE_CACHE_PATH =
PT_CACHE_FOLDER_DELETE = 0
PT_HPU_RECIPE_CACHE_CONFIG =
PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
PT_HPU_LAZY_ACC_PAR_MODE = 1
PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
---------------------------: System Configuration :---------------------------
Num CPU Cores : 152
CPU RAM : 1056440360 KB

INFO 08-12 19:19:01 loader.py:226] Loading weights on hpu ...
INFO 08-12 19:19:01 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 08-12 19:19:03 habana_model_runner.py:407] Pre-loading model weights on hpu:0 took 7.483 GiB of device memory (7.488 GiB/94.62 GiB used) and 315.3 MiB of host memory (123.8 GiB/1008 GiB used)
INFO 08-12 19:19:03 habana_model_runner.py:428] Loading model weights took in total 7.483 GiB of device memory (7.488 GiB/94.62 GiB used) and 516.1 MiB of host memory (123.9 GiB/1008 GiB used)
INFO 08-12 19:19:04 pynccl_utils.py:17] Failed to import NCCL library: NCCL only supports CUDA and ROCm backends.
INFO 08-12 19:19:04 pynccl_utils.py:18] It is expected if you are not running on NVIDIA GPUs.
SIGSEGV received at time=1723490346 on cpu 67
PC: @ 0xd (unknown) (unknown)
    @ 0x7f0a980ee520 1691771024 (unknown)
    @ 0x7f09cbcce59d (unknown) habana::ComputePermutationHashCode()
[2024-08-12 19:19:06,813 E 3021389 3021389] logging.cc:365: SIGSEGV received at time=1723490346 on cpu 67
[2024-08-12 19:19:06,815 E 3021389 3021389] logging.cc:365: PC: @ 0xd (unknown) (unknown)
[2024-08-12 19:19:06,817 E 3021389 3021389] logging.cc:365: @ 0x7f0a980ee520 1691771024 (unknown)
[2024-08-12 19:19:06,817 E 3021389 3021389] logging.cc:365: @ 0x7f09cbcce59d (unknown) habana::ComputePermutationHashCode()
Fatal Python error: Segmentation fault

Thread 0x00007eda97552640 (most recent call first):
  File "/usr/lib/python3.10/threading.py", line 324 in wait
  File "/usr/lib/python3.10/threading.py", line 607 in wait
  File "/usr/local/lib/python3.10/dist-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Thread 0x00007f0a065cd640 (most recent call first):
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2136 in listen_error_messages
  File "/usr/lib/python3.10/threading.py", line 953 in run
  File "/usr/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
  File "/usr/lib/python3.10/threading.py", line 973 in _bootstrap

Current thread 0x00007f0a97e15800 (most recent call first):
  File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 77 in replayV3
  File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 639 in wrapped_hpugraph_forward
  File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 742 in forward
  File "/root/vllm/vllm/worker/habana_model_runner.py", line 1027 in execute_model
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115 in decorate_context
  File "/root/vllm/vllm/worker/habana_model_runner.py", line 1156 in warmup_scenario
  File "/root/vllm/vllm/worker/habana_model_runner.py", line 1135 in profile_run
  File "/root/vllm/vllm/worker/habana_worker.py", line 139 in determine_num_available_blocks
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115 in decorate_context
  File "/root/vllm/vllm/worker/worker_base.py", line 137 in execute_method
  File "/root/vllm/vllm/executor/ray_habana_executor.py", line 217 in _run_workers
  File "/root/vllm/vllm/executor/distributed_gpu_executor.py", line 27 in determine_num_available_blocks
  File "/root/vllm/vllm/engine/llm_engine.py", line 250 in _initialize_kv_caches
  File "/root/vllm/vllm/engine/llm_engine.py", line 173 in __init__
  File "/root/vllm/vllm/engine/llm_engine.py", line 300 in from_engine_args
  File "/root/vllm/vllm/entrypoints/llm.py", line 123 in __init__
  File "/root/vllm/tests/entrypoints/test_llm_generate_TP.py", line 8 in test_multiple_sampling_params
  File "/usr/local/lib/python3.10/dist-packages/_pytest/python.py", line 159 in pytest_pyfunc_call
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/python.py", line 1627 in runtest
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 174 in pytest_runtest_call
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 242 in
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 341 in from_call
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 241 in call_and_report
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 132 in runtestprotocol
  File "/usr/local/lib/python3.10/dist-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 362 in pytest_runtestloop
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 337 in _main
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 283 in wrap_session
  File "/usr/local/lib/python3.10/dist-packages/_pytest/main.py", line 330 in pytest_cmdline_main
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py", line 103 in _multicall
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py", line 513 in __call__
  File "/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py", line 175 in main
  File "/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py", line 201 in console_main
  File "/usr/local/bin/pytest", line 8 in

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, PIL._imaging, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, yaml._yaml, sentencepiece._sentencepiece, PIL._imagingft, av._core, av.logging, av.bytesource, av.buffer, av.audio.format, av.enum, av.error, av.utils, av.option, av.descriptor, av.container.pyio, av.dictionary, av.format, av.stream, av.container.streams, av.sidedata.motionvectors, av.sidedata.sidedata, av.packet, av.container.input, av.container.output, av.container.core, av.codec.context, av.video.format, av.video.reformatter, av.plane, av.video.plane, av.video.frame, av.video.stream, av.codec.codec, av.frame, av.audio.layout, av.audio.plane, av.audio.frame, av.audio.stream, av.audio.fifo, av.filter.pad, av.filter.link, av.filter.context, av.filter.graph, av.filter.filter, av.audio.resampler, av.bitstream, psutil._psutil_linux, psutil._psutil_posix, msgpack._cmsgpack, google.protobuf.pyext._message, setproctitle, uvloop.loop, ray._raylet (total: 76)
Internal Error: Received signal - Segmentation fault


* error log 2 - HPUGraph replay fails with a mismatched input size

/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:366: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn(
2024-08-12 19:53:36,830 INFO worker.py:1749 -- Started a local Ray instance.
INFO 08-12 19:53:38 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='meta-llama/Meta-Llama-3-8B', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, weights_load_device=hpu, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=hpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=meta-llama/Meta-Llama-3-8B)
INFO 08-12 19:53:43 profiler.py:59] Profiler enabled for: vllm-instance-50d3be4640a0485eae00c926beda2b42
WARNING 08-12 19:53:43 utils.py:455] Pin memory is not supported on HPU.
INFO 08-12 19:53:43 selector.py:52] Using HabanaAttention backend.
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_PROMPT_BS_BUCKET_MIN=1 (default:1)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_PROMPT_BS_BUCKET_STEP=32 (default:32)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_PROMPT_BS_BUCKET_MAX=64 (default:64)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_DECODE_BS_BUCKET_MIN=32 (default:32)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_DECODE_BS_BUCKET_STEP=32 (default:32)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_DECODE_BS_BUCKET_MAX=128 (default:128)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_PROMPT_SEQ_BUCKET_MIN=128 (default:128)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_PROMPT_SEQ_BUCKET_STEP=128 (default:128)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_PROMPT_SEQ_BUCKET_MAX=1024 (default:1024)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_DECODE_BLOCK_BUCKET_MIN=128 (default:128)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_DECODE_BLOCK_BUCKET_STEP=128 (default:128)
INFO 08-12 19:53:43 habana_model_runner.py:138] VLLM_DECODE_BLOCK_BUCKET_MAX=2048 (default:2048)
INFO 08-12 19:53:43 habana_model_runner.py:459] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 64], seq:[128, 128, 1024]
INFO 08-12 19:53:43 habana_model_runner.py:460] Decode bucket config (min, step, max_warmup) bs:[32, 32, 128], seq:[128, 128, 2048]
============================= HABANA PT BRIDGE CONFIGURATION ===========================
PT_HPU_LAZY_MODE = 1
PT_RECIPE_CACHE_PATH =
PT_CACHE_FOLDER_DELETE = 0
PT_HPU_RECIPE_CACHE_CONFIG =
PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
PT_HPU_LAZY_ACC_PAR_MODE = 1
PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
---------------------------: System Configuration :---------------------------
Num CPU Cores : 152
CPU RAM : 1056440360 KB

INFO 08-12 19:53:48 loader.py:226] Loading weights on hpu ...
INFO 08-12 19:53:48 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 08-12 19:53:50 habana_model_runner.py:403] Pre-loading model weights on hpu:0 took 7.483 GiB of device memory (7.488 GiB/94.62 GiB used) and 291.8 MiB of host memory (123.9 GiB/1008 GiB used)
INFO 08-12 19:53:51 habana_model_runner.py:424] Loading model weights took in total 7.483 GiB of device memory (7.488 GiB/94.62 GiB used) and 479.7 MiB of host memory (124.1 GiB/1008 GiB used)
INFO 08-12 19:53:51 pynccl_utils.py:17] Failed to import NCCL library: NCCL only supports CUDA and ROCm backends.
INFO 08-12 19:53:51 pynccl_utils.py:18] It is expected if you are not running on NVIDIA GPUs.
HPUGraph capture, forward took 2.2133708000183105 seconds
INFO 08-12 19:54:01 distributed_gpu_executor.py:45] # GPU blocks: 5837, # CPU blocks: 512
INFO 08-12 19:54:01 habana_model_runner.py:1217] Skipping warmup...

Processed prompts: 0%| | 0/4 [00:00<?, ?it/s]
WARNING 08-12 19:54:01 habana_model_runner.py:937] Configuration: (prompt, 4, 128) was not warmed-up!
HPUGraph capture, forward took 0.3346521854400635 seconds
WARNING 08-12 19:54:02 habana_model_runner.py:937] Configuration: (decode, 32, 128) was not warmed-up!
HPUGraph capture, forward took 0.42697668075561523 seconds
ERROR 08-12 19:54:02 worker_base.py:145] Error executing method execute_model. This might cause deadlock in distributed execution.
ERROR 08-12 19:54:02 worker_base.py:145] Traceback (most recent call last):
ERROR 08-12 19:54:02 worker_base.py:145]   File "/root/vllm/vllm/worker/worker_base.py", line 137, in execute_method
ERROR 08-12 19:54:02 worker_base.py:145]     return executor(*args, **kwargs)
ERROR 08-12 19:54:02 worker_base.py:145]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 08-12 19:54:02 worker_base.py:145]     return func(*args, **kwargs)
ERROR 08-12 19:54:02 worker_base.py:145]   File "/root/vllm/vllm/worker/habana_worker.py", line 244, in execute_model
ERROR 08-12 19:54:02 worker_base.py:145]     output = self.model_runner.execute_model(seq_group_metadata_list,
ERROR 08-12 19:54:02 worker_base.py:145]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 08-12 19:54:02 worker_base.py:145]     return func(*args, **kwargs)
ERROR 08-12 19:54:02 worker_base.py:145]   File "/root/vllm/vllm/worker/habana_model_runner.py", line 1023, in execute_model
ERROR 08-12 19:54:02 worker_base.py:145]     hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices)
ERROR 08-12 19:54:02 worker_base.py:145]   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 742, in forward
ERROR 08-12 19:54:02 worker_base.py:145]     return wrapped_hpugraph_forward(
ERROR 08-12 19:54:02 worker_base.py:145]   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 639, in wrapped_hpugraph_forward
ERROR 08-12 19:54:02 worker_base.py:145]     graph.replayV3(input_tensor_list, asynchronous)
ERROR 08-12 19:54:02 worker_base.py:145]   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 77, in replayV3
ERROR 08-12 19:54:02 worker_base.py:145]     _hpu_C.replayV3(self.hpu_graph, tlistI, asynchronous)
ERROR 08-12 19:54:02 worker_base.py:145] RuntimeError: Input stack size=1 is not matching with #graph_inputs=3
rank0: Traceback (most recent call last):
rank0:   File "/root/vllm/test_llm_generate_TP.py", line 30, in

rank0:   File "/root/vllm/test_llm_generate_TP.py", line 24, in test_multiple_sampling_params
rank0:     output = llm.generate(prompts, sampling_params=sampling_params)
rank0:   File "/root/vllm/vllm/entrypoints/llm.py", line 222, in generate
rank0:     return self._run_engine(use_tqdm)
rank0:   File "/root/vllm/vllm/entrypoints/llm.py", line 250, in _run_engine
rank0:     step_outputs = self.llm_engine.step()
rank0:   File "/root/vllm/vllm/engine/llm_engine.py", line 606, in step
rank0:     output = self.model_executor.execute_model(
rank0:   File "/root/vllm/vllm/executor/ray_habana_executor.py", line 155, in execute_model
rank0:     all_outputs = self._run_workers(
rank0:   File "/root/vllm/vllm/executor/ray_habana_executor.py", line 217, in _run_workers
rank0:     driver_worker_output = self.driver_worker.execute_method(
rank0:   File "/root/vllm/vllm/worker/worker_base.py", line 146, in execute_method
rank0:     raise e
rank0:   File "/root/vllm/vllm/worker/worker_base.py", line 137, in execute_method
rank0:     return executor(*args, **kwargs)
rank0:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
rank0:     return func(*args, **kwargs)
rank0:   File "/root/vllm/vllm/worker/habana_worker.py", line 244, in execute_model
rank0:     output = self.model_runner.execute_model(seq_group_metadata_list,
rank0:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
rank0:     return func(*args, **kwargs)
rank0:   File "/root/vllm/vllm/worker/habana_model_runner.py", line 1023, in execute_model
rank0:     hidden_states = self.model.forward(**execute_model_kwargs, selected_token_indices=sampling_metadata.selected_token_indices)
rank0:   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 742, in forward
rank0:     return wrapped_hpugraph_forward(
rank0:   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 639, in wrapped_hpugraph_forward
rank0:     graph.replayV3(input_tensor_list, asynchronous)
rank0:   File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/graphs.py", line 77, in replayV3
rank0:     _hpu_C.replayV3(self.hpu_graph, tlistI, asynchronous)
rank0: RuntimeError: Input stack size=1 is not matching with #graph_inputs=3

Processed prompts: 0%| | 0/4 [00:03<?, ?it/s]

xuechendi commented 1 month ago

@madamczykhabana, please help take a look at this issue.

I found a way to quickly fix this segmentation fault issue by:

  1. Reverting PR #115.
  2. Adding tensor materialization before the forward call to the HPUGraph (see the attached screenshot; a rough sketch of what I mean is below).

However, I think this is just a temporary workaround; I would like to get your insight here. Thanks.
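
To make step 2 concrete, here is a minimal sketch of what I mean by "tensor materialize", assuming htcore.mark_step() is used to flush the lazy-mode accumulation queue before the captured graph runs. The helper name and the exact call site are illustrative only; the actual change is in the screenshot above.

# Sketch only - not the exact patch from the screenshot above.
# Idea: in lazy mode, flush pending ops so the inputs handed to the captured
# HPUGraph are fully materialized before replayV3 is called.
import habana_frameworks.torch.core as htcore  # HPU lazy-mode bridge


def materialize_inputs(*tensors):
    """Flush the lazy accumulation queue so `tensors` are backed by device memory."""
    htcore.mark_step()
    return tensors


# Approximate call site inside habana_model_runner.execute_model,
# right before the HPUGraph-wrapped forward:
#
#     materialize_inputs(*execute_model_kwargs.values())
#     hidden_states = self.model.forward(
#         **execute_model_kwargs,
#         selected_token_indices=sampling_metadata.selected_token_indices,
#     )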