mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] FlashInfer decode BeginForward error an illegal instruction was encountered #2509

Open zifeitong opened 1 month ago

zifeitong commented 1 month ago

πŸ› Bug

I am seeing an illegal instruction error when running mlc-llm on an H100 (Driver Version: 550.54.1, CUDA Version: 12.4). It works fine on another machine with an A100 (Driver Version: 535.161.08, CUDA Version: 12.2).

To Reproduce

Steps to reproduce the behavior:

  1. python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu122 mlc-ai-nightly-cu122
  2. mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/var/data/home/.venv/lib/python3.10/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/workspace/mlc-llm/cpp/serve/threaded_engine.cc", line 169, in mlc::llm::serve::ThreadedEngineImpl::RunBackgroundLoop()
  File "/workspace/mlc-llm/cpp/serve/engine.cc", line 611, in mlc::llm::serve::EngineImpl::Step()
  File "/workspace/mlc-llm/cpp/serve/engine_actions/batch_decode.cc", line 107, in mlc::llm::serve::BatchDecodeActionObj::Step(mlc::llm::serve::EngineState)
  File "/workspace/mlc-llm/cpp/serve/model.cc", line 380, in mlc::llm::serve::ModelImpl::BatchDecode(tvm::runtime::ObjectRef const&, std::vector<long, std::allocator<long> > const&)
tvm._ffi.base.TVMError: Traceback (most recent call last):
  16: mlc::llm::serve::ThreadedEngineImpl::RunBackgroundLoop()
        at /workspace/mlc-llm/cpp/serve/threaded_engine.cc:169
  15: mlc::llm::serve::EngineImpl::Step()
        at /workspace/mlc-llm/cpp/serve/engine.cc:611
  14: mlc::llm::serve::BatchDecodeActionObj::Step(mlc::llm::serve::EngineState)
        at /workspace/mlc-llm/cpp/serve/engine_actions/batch_decode.cc:107
  13: mlc::llm::serve::ModelImpl::BatchDecode(tvm::runtime::ObjectRef const&, std::vector<long, std::allocator<long> > const&)
        at /workspace/mlc-llm/cpp/serve/model.cc:380
  12: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::GetClosureInternal(tvm::runtime::String const&, bool)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  10: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
  9: tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
  8: tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
  7: _ZN3tvm7runtime13Packed
  6: tvm::runtime::TypedPackedFunc<void (tvm::runtime::relax_vm::AttentionKVCache, long, double, tvm::runtime::NDArray, tvm::runtime::NDArray)>::AssignTypedLambda<tvm::runtime::relax_vm::__mk_TVM17::{lambda(tvm::runtime::relax_vm::AttentionKVCache, long, double, tvm::runtime::NDArray, tvm::runtime::NDArray)#1}>(tvm::runtime::relax_vm::__mk_TVM17::{lambda(tvm::runtime::relax_vm::AttentionKVCache, long, double, tvm::runtime::NDArray, tvm::runtime::NDArray)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const [clone .constprop.0]
  5: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::AttentionWithFusedQKV(long, tvm::runtime::NDArray, tvm::runtime::Optional<tvm::runtime::NDArray>, tvm::runtime::NDArray, double)
  4: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::KernelBeginForward()
  3: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  2: tvm::runtime::TypedPackedFunc<void (long, DLTensor*, DLTensor*, DLTensor*, long, long, long, long, long, void*)>::AssignTypedLambda<void (*)(long, DLTensor*, DLTensor*, DLTensor*, long, long, long, long, long, void*)>(void (*)(long, DLTensor*, DLTensor*, DLTensor*, long, long, long, long, long, void*), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  1: _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(long, DLTensor*, DLTensor*, DLTensor*, long, long, long, long, long, void*)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/3rdparty/flashinfer/src/tvm_wrapper.cu", line 437
TVMError: FlashInfer decode BeginForward error an illegal instruction was encountered

Expected behavior

Environment

Thanks!

yzh119 commented 1 month ago

Thanks for reporting this issue, I'll take a look.

yzh119 commented 3 weeks ago

I ran the inference code again with CUDA_LAUNCH_BLOCKING=1 and got:

tvm._ffi.base.TVMError: Traceback (most recent call last):
  2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_6detail17PackFuncVoidAddr_ILi4ENS0_15CUDAWrappedFuncEEENS0_10PackedFuncET0_RKSt6vectorINS4_1
  1: tvm::runtime::CUDAWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*, void**) const [clone .isra.0]
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/runtime/cuda/cuda_module.cc", line 206
TVMError: CUDALaunch Error: CUDA_ERROR_ILLEGAL_INSTRUCTION
 grid=(3072,1,1),  block=(64,4,1)

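These grid and block dimensions do not necessarily belong to a FlashInfer kernel. (Without CUDA_LAUNCH_BLOCKING=1, a kernel fault surfaces asynchronously at some later CUDA call, which would explain why the first trace pointed at FlashInfer's BeginForward.)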

This needs more investigation.
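The same rerun can be done from the shell (`CUDA_LAUNCH_BLOCKING=1 mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC`) or from a script. A minimal Python sketch, assuming `mlc_llm` is on `PATH`; the helper names `blocking_launch_env` and `rerun_with_blocking_launches` are hypothetical and not part of mlc-llm:

```python
import os
import subprocess
from typing import Dict, List, Mapping, Optional

# Reproduction command taken from the steps in the issue.
REPRO_CMD: List[str] = ["mlc_llm", "chat", "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC"]


def blocking_launch_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with synchronous CUDA kernel launches.

    With CUDA_LAUNCH_BLOCKING=1 each launch waits for the kernel to finish, so a
    fault is reported at the launch of the faulting kernel instead of at a later,
    unrelated CUDA call. The input mapping is not modified.
    """
    env = dict(os.environ if base is None else base)
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


def rerun_with_blocking_launches(cmd: Optional[List[str]] = None) -> int:
    """Run the reproduction command with blocking launches and return its exit code."""
    completed = subprocess.run(cmd or REPRO_CMD, env=blocking_launch_env(), check=False)
    return completed.returncode
```

Blocking launches serialize the GPU and slow inference down noticeably, so this is only meant for locating the failing kernel, not for normal runs.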

yzh119 commented 3 weeks ago

It seems grid=(3072,1,1), block=(64,4,1) is the launch configuration of the TIR NT_matmul kernel:

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0], metadata["runtime.Module"][1], metadata["runtime.Module"][2], metadata["runtime.Module"][3], metadata["runtime.Module"][4], metadata["runtime.Module"][5], metadata["runtime.Module"][6], metadata["runtime.Module"][7], metadata["runtime.Module"][8], metadata["runtime.Module"][9], metadata["runtime.Module"][10], metadata["runtime.Module"][11]]})
    @T.prim_func(private=True)
    def NT_matmul(rms_norm65: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), model_layers_0_self_attn_qkv_proj_weight3: T.Buffer((T.int64(12288), T.int64(4096)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(12288)), "float16")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        NT_matmul_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(12288)), "float16", scope="local")
        NT_matmul_rf_local_1 = T.alloc_buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(12288)), "float16", scope="local")
        model_layers_0_self_attn_qkv_proj_weight3_local = T.alloc_buffer((T.int64(12288), T.int64(4096)), "float16", scope="local")
        rms_norm65_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared")
        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(3072), thread="blockIdx.x"):
            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        for ax2_0 in T.serial(T.int64(2), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
                            for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                                for ax2_2 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                                    for ax2_3 in T.vectorized(T.int64(8)):
                                        with T.block("rms_norm65_shared"):
                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(2048) + ax2_1 * T.int64(512) + ax2_2 * T.int64(8) + ax2_3)
                                            T.reads(rms_norm65[v0, v1, v2])
                                            T.writes(rms_norm65_shared[v0, v1, v2])
                                            rms_norm65_shared[v0, v1, v2] = rms_norm65[v0, v1, v2]
                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
                        for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init)
                                v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
                                T.reads()
                                T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
                                NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
                    for ax1_fused_u_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax0_ax1_fused_0 in range(T.int64(4)):
                            for ax0_ax1_fused_1 in T.vectorized(T.int64(2)):
                                with T.block("model_layers_0_self_attn_qkv_proj_weight3_local"):
                                    v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1)
                                    v1 = T.axis.spatial(T.int64(4096), ax1_fused_u_fused_0 * T.int64(512) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1)
                                    T.reads(model_layers_0_self_attn_qkv_proj_weight3[v0, v1])
                                    T.writes(model_layers_0_self_attn_qkv_proj_weight3_local[v0, v1])
                                    model_layers_0_self_attn_qkv_proj_weight3_local[v0, v1] = model_layers_0_self_attn_qkv_proj_weight3[v0, v1]
                        for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)):
                            for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1)
                                    v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
                                    vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2])
                                    T.reads(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], rms_norm65_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_layers_0_self_attn_qkv_proj_weight3_local[v0, vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)])
                                    T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
                                    NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + rms_norm65_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_layers_0_self_attn_qkv_proj_weight3_local[v0, vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]
            for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                    for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        for ax2_fused_2_1 in T.vectorized(T.int64(1)):
                            with T.block("NT_matmul_rf_init"):
                                vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0)
                                v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                T.reads()
                                T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0)
                            for ax1 in range(T.int64(4)):
                                with T.block("NT_matmul_rf_update"):
                                    vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
                                    v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
                                    T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0])
                                    T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                                    NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]
            for ax1_fused_2 in range(T.int64(1)):
                for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
                        with T.block("NT_matmul"):
                            vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0)
                            v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
                            T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
                            T.writes(NT_matmul[T.int64(0), T.int64(0), v0])
                            with T.init():
                                NT_matmul[T.int64(0), T.int64(0), v0] = T.float16(0)
                            NT_matmul[T.int64(0), T.int64(0), v0] = NT_matmul[T.int64(0), T.int64(0), v0] + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]
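The link between that launch configuration and the schedule above can be checked with a pure-Python emulation of its index arithmetic. This is a sketch, not TVM code: it assumes the loop extents read off the TIR (4 output rows per block on threadIdx.y, 64 threads on threadIdx.x, 4 vectorized lanes, 512-element chunks of the reduction axis), uses Python floats instead of float16, and runs the three reduction stages sequentially rather than on parallel threads. The function names are hypothetical.

```python
from typing import List, Tuple

# Loop extents read off the NT_matmul TIR above.
ROWS_PER_BLOCK = 4   # threadIdx.y extent: output rows handled by one block
THREADS_X = 64       # threadIdx.x extent: threads splitting the reduction axis
LANES = 4            # vectorized lanes per thread in NT_matmul_rf_update
CHUNK = 512          # reduction elements consumed per ax1_fused_u_fused_0 iteration


def launch_config(n_rows: int) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]:
    """Grid and block dims for a weight with n_rows output rows."""
    assert n_rows % ROWS_PER_BLOCK == 0
    return (n_rows // ROWS_PER_BLOCK, 1, 1), (THREADS_X, ROWS_PER_BLOCK, 1)


def nt_matmul_emulated(x: List[float], w: List[List[float]]) -> List[float]:
    """out[v0] = sum_k x[k] * w[v0][k], reduced in the same three stages as the TIR."""
    n_rows, k_dim = len(w), len(x)
    assert k_dim % CHUNK == 0
    (grid_x, _, _), _ = launch_config(n_rows)
    out = [0.0] * n_rows
    for bx in range(grid_x):                        # blockIdx.x
        for ty in range(ROWS_PER_BLOCK):            # threadIdx.y
            v0 = bx * ROWS_PER_BLOCK + ty
            # Stage 1 (NT_matmul_rf_local): 256 partial sums, index r = tx * 4 + lane.
            rf = [0.0] * (THREADS_X * LANES)
            for tx in range(THREADS_X):             # threadIdx.x
                for ko in range(k_dim // CHUNK):    # ax1_fused_u_fused_0
                    for k2 in range(2):             # ax1_fused_u_fused_2
                        for lane in range(LANES):
                            k = ko * CHUNK + tx * 8 + k2 * 4 + lane
                            rf[tx * LANES + lane] += x[k] * w[v0][k]
            # Stage 2 (NT_matmul_rf_local_1): fold the 4 lanes owned by each thread.
            rf1 = [sum(rf[tx * LANES + lane] for lane in range(LANES)) for tx in range(THREADS_X)]
            # Stage 3 (block "NT_matmul"): cross-thread reduction over threadIdx.x.
            out[v0] = sum(rf1)
    return out
```

For the 12288 x 4096 weight in the TIR, `launch_config(12288)` gives 3072 blocks of 64 x 4 threads, matching the failing launch, and each thread accumulates 8 x 2 x 4 = 64 of the 4096 reduction elements per row.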

MasterJH5574 commented 3 weeks ago

Thank you so much @yzh119 for triaging and locating the issue 🙏 We'll look into it further and report back with updates.

avianion commented 2 weeks ago

Exact same issue here.