sgl-project / sglang

SGLang is a fast serving framework for large language models and vision language models.
https://sglang.readthedocs.io/en/latest/
Apache License 2.0

[Bug] T4 not work #1058

Closed · zhyncs closed this issue 1 month ago

zhyncs commented 1 month ago

Describe the bug

T4 does not work, with or without FlashInfer. Ref: https://github.com/flashinfer-ai/flashinfer/issues/421

CUDA Error: no kernel image is available for execution on the device (209) /tmp/build-via-sdist-iemil769/flashinfer-0.1.4+cu121torch2.4/include/flashinfer/attention/handler.cuh: line 169 at function cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads, smem_size)
CUDA Error: no kernel image is available for execution on the device (209) /tmp/build-via-sdist-iemil769/flashinfer-0.1.4+cu121torch2.4/include/flashinfer/attention/handler.cuh: line 324 at function work_estimation_func(split_kv, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr_h, num_qo_heads, page_size, IsCUDAGraphEnabled(), stream_)
Process Process-1:
Initialization failed. controller_init_state: Traceback (most recent call last):
  File "/content/sglang/python/sglang/srt/model_executor/model_runner.py", line 344, in init_cuda_graphs
    self.cuda_graph_runner.capture(batch_size_list)
  File "/content/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 148, in capture
    ) = self.capture_one_batch_size(bs, forward)
  File "/content/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 183, in capture_one_batch_size
    update_flashinfer_indices(
  File "/content/sglang/python/sglang/srt/model_executor/forward_batch_info.py", line 284, in update_flashinfer_indices
    flashinfer_decode_wrapper.begin_forward(
  File "/usr/local/lib/python3.10/dist-packages/flashinfer/decode.py", line 525, in begin_forward
    self._wrapper.begin_forward(
RuntimeError: BatchDecodeWithPagedKVCache failed with error no kernel image is available for execution on the device

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/content/sglang/python/sglang/srt/managers/controller_single.py", line 150, in start_controller_process
    controller = ControllerSingle(
  File "/content/sglang/python/sglang/srt/managers/controller_single.py", line 84, in __init__
    self.tp_server = ModelTpServer(
  File "/content/sglang/python/sglang/srt/managers/tp_worker.py", line 100, in __init__
    self.model_runner = ModelRunner(
  File "/content/sglang/python/sglang/srt/model_executor/model_runner.py", line 139, in __init__
    self.init_cuda_graphs()
  File "/content/sglang/python/sglang/srt/model_executor/model_runner.py", line 346, in init_cuda_graphs
    raise Exception(
Exception: Capture cuda graph failed: BatchDecodeWithPagedKVCache failed with error no kernel image is available for execution on the device
Possible solutions:
1. disable torch compile by not using --enable-torch-compile
2. disable cuda graph by --disable-cuda-graph
3. set --mem-fraction-static to a smaller value
Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose 

Initialization failed. detoken_init_state: init ok
INFO:     127.0.0.1:59800 - "GET /get_model_info HTTP/1.1" 200 OK
[gpu=0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, cache hit rate: 0.00%, #running-req: 0, #queue-req: 0
Exception in ModelTpServer:
Traceback (most recent call last):
  File "/content/sglang/python/sglang/srt/managers/tp_worker.py", line 222, in exposed_step
    self.forward_step()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/managers/tp_worker.py", line 238, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/content/sglang/python/sglang/srt/managers/tp_worker.py", line 452, in forward_prefill_batch
    output = self.model_runner.forward(batch, ForwardMode.EXTEND)
  File "/content/sglang/python/sglang/srt/model_executor/model_runner.py", line 397, in forward
    return self.forward_extend(batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/model_executor/model_runner.py", line 373, in forward_extend
    return self.model.forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 287, in forward
    hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 255, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 207, in forward
    hidden_states = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 156, in forward
    attn_output = self.attn(q, k, v, input_metadata)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/layers/radix_attention.py", line 177, in forward
    return self.extend_forward(q, k, v, input_metadata)
  File "/content/sglang/python/sglang/srt/layers/radix_attention.py", line 69, in extend_forward_triton
    extend_attention_fwd(
  File "/content/sglang/python/sglang/srt/layers/extend_attention.py", line 291, in extend_attention_fwd
    _fwd_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 318, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 216, in make_llir
    pm.run(mod)
IndexError: map::at

Exception in ControllerSingle:
Traceback (most recent call last):
  File "/content/sglang/python/sglang/srt/managers/controller_single.py", line 166, in start_controller_process
    controller.loop_for_forward()
  File "/content/sglang/python/sglang/srt/managers/controller_single.py", line 103, in loop_for_forward
    out_pyobjs = self.tp_server.exposed_step(recv_reqs)
  File "/content/sglang/python/sglang/srt/managers/tp_worker.py", line 222, in exposed_step
    self.forward_step()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/managers/tp_worker.py", line 238, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/content/sglang/python/sglang/srt/managers/tp_worker.py", line 452, in forward_prefill_batch
    output = self.model_runner.forward(batch, ForwardMode.EXTEND)
  File "/content/sglang/python/sglang/srt/model_executor/model_runner.py", line 397, in forward
    return self.forward_extend(batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/model_executor/model_runner.py", line 373, in forward_extend
    return self.model.forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 287, in forward
    hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 255, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 207, in forward
    hidden_states = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/models/qwen2.py", line 156, in forward
    attn_output = self.attn(q, k, v, input_metadata)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/sglang/python/sglang/srt/layers/radix_attention.py", line 177, in forward
    return self.extend_forward(q, k, v, input_metadata)
  File "/content/sglang/python/sglang/srt/layers/radix_attention.py", line 69, in extend_forward_triton
    extend_attention_fwd(
  File "/content/sglang/python/sglang/srt/layers/extend_attention.py", line 291, in extend_attention_fwd
    _fwd_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 282, in compile
    next_module = compile_ir(module, metadata)
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 318, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 216, in make_llir
    pm.run(mod)
IndexError: map::at

[rank0]:W0812 14:38:42.916000 137655740245568 torch/_inductor/compile_worker/subproc_pool.py:126] SubprocPool unclean exit

Reproduction

# Use the last release branch
git clone -b v0.2.12 https://github.com/sgl-project/sglang.git
cd sglang

pip install --upgrade pip
pip install -e "python[all]"

# Install FlashInfer CUDA kernels
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/

python3 -m sglang.launch_server --model Qwen/Qwen1.5-4B-Chat

python3 -m sglang.launch_server --model Qwen/Qwen1.5-4B-Chat --disable-flashinfer --disable-flashinfer-sampling

Environment

Python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
CUDA available: True
GPU 0: Tesla T4
GPU 0 Compute Capability: 7.5
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.2, V12.2.140
CUDA Driver Version: 535.104.05
PyTorch: 2.4.0+cu121
sglang: 0.2.12
flashinfer: 0.1.4+cu121torch2.4
triton: 3.0.0
transformers: 4.44.0
requests: 2.32.3
tqdm: 4.66.5
numpy: 1.26.4
aiohttp: 3.10.1
fastapi: 0.112.0
hf_transfer: 0.1.8
huggingface_hub: 0.23.5
interegular: 0.3.3
packaging: 24.1
PIL: 9.4.0
psutil: 5.9.5
pydantic: 2.8.2
uvicorn: 0.30.5
uvloop: 0.19.0
zmq: 24.0.1
vllm: 0.5.4
multipart: 0.0.9
openai: 1.40.3
anthropic: 0.33.0
NVIDIA Topology: 
    GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X  0-1     N/A     N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

ulimit soft: 1048576
zhyncs commented 1 month ago

This verification was done on Google Colab. The T4 is sm75 with 16 GB of VRAM, so even if it worked, fp16 would limit it to models under 7B; it cannot even fit Llama 3.1 8B. After disabling FlashInfer, the Triton backend still errors out, and the Triton nightly build did not help either.
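As a rough back-of-envelope check (my own estimate, not taken from the logs), fp16 weights alone put an 8B model right at the T4's limit, which is why only sub-7B models are practical here:

# fp16 weight footprint only: 2 bytes/param, ignoring KV cache, activations, and CUDA context
for n_params in (4e9, 7e9, 8e9):
    print(f"{n_params / 1e9:.0f}B params -> {n_params * 2 / 1024**3:.1f} GiB")
# 4B -> 7.5 GiB, 7B -> 13.0 GiB, 8B -> 14.9 GiB, vs. roughly 15 GiB usable on a 16 GB T4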

zhyncs commented 1 month ago

cc @ispobock Perhaps you could help take a look at this issue.

ispobock commented 1 month ago

@zhyncs Try specifying --dtype float16 for the T4.

ref: https://github.com/state-spaces/mamba/issues/361#issuecomment-2181263738
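For example, taking the --disable-flashinfer command from the reproduction above and adding the flag (assuming the same model):

python3 -m sglang.launch_server --model Qwen/Qwen1.5-4B-Chat --disable-flashinfer --disable-flashinfer-sampling --dtype float16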

zhyncs commented 1 month ago

> @zhyncs Try specifying --dtype float16 for the T4.
>
> ref: state-spaces/mamba#361 (comment)

Interesting workaround. I'll add sm75 support to FlashInfer. For now it's just a test branch: https://github.com/flashinfer-ai/flashinfer/compare/main...sm75

zhyncs commented 1 month ago

ref https://github.com/flashinfer-ai/flashinfer/actions/runs/10390772247/job/28772019389

jeejeelee commented 1 month ago

It might be due to bf16. SM75 doesn't support bf16.
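A quick standalone check of the compute capability is consistent with that (a minimal PyTorch sketch, not sglang code; native bf16 needs sm80 or newer):

import torch

# Turing cards like the T4 report (7, 5); native bf16 needs Ampere (sm80) or newer
major, minor = torch.cuda.get_device_capability(0)
if (major, minor) < (8, 0):
    print(f"sm{major}{minor}: no native bf16, fall back to float16")
else:
    print(f"sm{major}{minor}: bf16 is available")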

zhyncs commented 1 month ago

ref https://github.com/sgl-project/sglang/pull/1136 https://github.com/flashinfer-ai/flashinfer/pull/448 https://github.com/flashinfer-ai/flashinfer/pull/449

zhyncs commented 1 month ago

Fixed by https://github.com/sgl-project/sglang/pull/1233.

Thanks to @yzh119 for the support!