vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: vLLM 0.5.5 and FlashInfer 0.1.6 #8091

Open wlwqq opened 2 months ago

wlwqq commented 2 months ago

Your current environment

```text
vllm == 0.5.5
FlashInfer == 0.1.6+cu121torch2.4
```
(full `python collect_env.py` output not provided)

🐛 Describe the bug

When I use vLLM 0.5.5 and FlashInfer 0.1.6 to run Gemma-2-2b on a T4, the engine crashes. FlashInfer 0.1.6 supports the T4 (https://github.com/flashinfer-ai/flashinfer/releases), but I see:

INFO 09-02 16:07:55 model_runner.py:890] Loading model weights took 4.8999 GB
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/entrypoints/openai/rpc/server.py", line 230, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
  File "/opt/conda/lib/python3.10/site-packages/vllm/entrypoints/openai/rpc/server.py", line 31, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(
  File "/opt/conda/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 740, in from_engine_args
    engine = cls(
  File "/opt/conda/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 636, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 840, in _init_engine
    return engine_class(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 272, in __init__
    super().__init__(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 284, in __init__
    self._initialize_kv_caches()
  File "/opt/conda/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 390, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
  File "/opt/conda/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 113, in determine_num_available_blocks
    return self.driver_worker.determine_num_available_blocks()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/worker/worker.py", line 222, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1097, in profile_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1415, in execute_model
    hidden_or_intermediate_states = model_executable(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/models/gemma2.py", line 342, in forward
    hidden_states = self.model(input_ids, positions, kv_caches,
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/models/gemma2.py", line 281, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/models/gemma2.py", line 225, in forward
    hidden_states = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/model_executor/models/gemma2.py", line 165, in forward
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/attention/layer.py", line 98, in forward
    return self.impl.forward(query,
  File "/opt/conda/lib/python3.10/site-packages/vllm/attention/backends/flashinfer.py", line 688, in forward
    output = torch.ops.vllm.flash_attn_varlen_func(
  File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 1061, in __call__
    return self_._op(*args, **(kwargs or {}))
  File "/opt/conda/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 236, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/vllm/attention/backends/flash_attn.py", line 48, in flash_attn_varlen_func
    return _flash_attn_varlen_func(
  File "/opt/conda/lib/python3.10/site-packages/vllm_flash_attn/flash_attn_interface.py", line 1154, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.10/site-packages/vllm_flash_attn/flash_attn_interface.py", line 632, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/opt/conda/lib/python3.10/site-packages/vllm_flash_attn/flash_attn_interface.py", line 90, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only supports Ampere GPUs or newer.

I'm not sure whether this is a problem with vLLM's FlashInfer integration. @youkaichao @LiuXiaoxuanPKU
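
For reference, a quick check of the GPU generation on the same machine (a minimal diagnostic sketch):

```python
import torch

# Minimal diagnostic sketch: the failing flash-attn kernel requires compute
# capability >= 8.0 (Ampere); a T4 is Turing and reports (7, 5).
major, minor = torch.cuda.get_device_capability()
print(f"GPU: {torch.cuda.get_device_name(0)}, compute capability {major}.{minor}")
print("Ampere or newer:", (major, minor) >= (8, 0))
```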


jonzhep commented 2 months ago

RuntimeError: FlashAttention only supports Ampere GPUs or newer.

yzh119 commented 2 months ago

If you hit this error, you are not using the FlashInfer backend; the message is reported by the flash-attn package.

Try setting VLLM_ATTENTION_BACKEND=FLASHINFER and running the script again.
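
For example, something like this (a minimal sketch; the model name, dtype, and prompt are placeholders, and the variable must be set before vLLM initializes):

```python
import os

# Select the attention backend before vLLM picks an implementation.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# Model and dtype are illustrative; the T4 has no bfloat16 support,
# so this sketch uses float16.
llm = LLM(model="google/gemma-2-2b", dtype="float16")
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```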

youkaichao commented 2 months ago

see https://github.com/vllm-project/vllm/issues/8189#issuecomment-2332350166

@yzh119 we do have some TODOs for the FlashInfer backend; until those are done, the FlashInfer backend still depends on flash-attention for prefill :(
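
Roughly, the split looks like this (an illustrative sketch based on the traceback above, not the actual backend code):

```python
# Illustrative sketch (not vLLM's real code): today the FlashInfer backend
# still routes prefill through flash-attn's varlen kernel, which is the
# Ampere-only call that fails on a T4, while decode uses FlashInfer's own
# paged-KV kernels.
def attention_path(is_prefill: bool) -> str:
    if is_prefill:
        # traceback: vllm/attention/backends/flashinfer.py ->
        #            torch.ops.vllm.flash_attn_varlen_func
        return "flash-attn varlen kernel (Ampere or newer only)"
    return "FlashInfer paged-KV decode kernel (runs on T4)"

# The profile run that crashes above is a prefill pass.
print(attention_path(is_prefill=True))
```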

yzh119 commented 2 months ago

Hi @youkaichao, thanks for letting me know! FlashInfer v0.1.7 will be fully JIT, and I'll publish it as a PyPI package that can be set as a vLLM dependency. I'll keep you posted on the progress.