mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] block_infos returned as None in GEMV #614

Closed · sbwww closed this issue 1 year ago

sbwww commented 1 year ago

🐛 Bug

I encountered the following error when I was trying to add a custom relax_model and compile it.

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/root/mlc-llm/mlc_llm/build.py", line 13, in <module>
    main()
  File "/root/mlc-llm/mlc_llm/build.py", line 10, in main
    core.build_model_from_args(parsed_args)
  File "/root/mlc-llm/mlc_llm/core.py", line 450, in build_model_from_args
    build(mod, args)
  File "/root/mlc-llm/mlc_llm/core.py", line 370, in build
    mod_deploy = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod_deploy)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
TypeError: Traceback (most recent call last):
  5: TVMFuncCall
  4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  3: tvm::transform::Pass::operator()(tvm::IRModule) const
  2: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/dlight/base/transform.py", line 64, in transform_module
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    sch = _apply_rules(func, target, self.rules, tunable=False)
  File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/dlight/base/transform.py", line 80, in _apply_rules
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    space = rule.apply(func, target, tunable)
  File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/dlight/gpu/gemv.py", line 165, in apply
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    if len(block_infos) == 1:
       ^^^^^^^^^^^^^^^^
TypeError: object of type 'NoneType' has no len()

In tvm.dlight.gpu.gemv.GEMV (shown below), block_infos can be returned as None by normalize_prim_func() and try_inline_contiguous_spatial(), so calling len(block_infos) without a None check raises the error above. It may well be something wrong on my end, but I cannot tell how to locate the problem from the error message alone. A defensive guard like the one sketched after the snippet would at least avoid the crash.

    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
        self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
        if not isinstance(func, tir.PrimFunc):
            return None
        sch = tir.Schedule(func)
        block_infos = normalize_prim_func(sch)
        block_infos = try_inline_contiguous_spatial(sch, block_infos)
        if len(block_infos) == 1:
            epilogue = None
        elif len(block_infos) == 2:
            epilogue = block_infos[1]
            if not epilogue.is_injective():
                return None
        else:
            return None
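
It may not be the right fix, but a minimal sketch of the guard I would expect (same code as above, with early returns when the helpers yield None) is:

    def apply(  # sketch only: the snippet above with None checks added
        self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
        if not isinstance(func, tir.PrimFunc):
            return None
        sch = tir.Schedule(func)
        block_infos = normalize_prim_func(sch)
        if block_infos is None:
            # nothing this rule can handle: skip it instead of crashing on len(None)
            return None
        block_infos = try_inline_contiguous_spatial(sch, block_infos)
        if block_infos is None:
            return None
        if len(block_infos) == 1:
            epilogue = None
        elif len(block_infos) == 2:
            epilogue = block_infos[1]
            if not epilogue.is_injective():
                return None
        else:
            return None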

To Reproduce

Steps to reproduce the behavior:

1. See the RedPajama compile command in the comments below.

Hzfengsy commented 1 year ago

Thanks for reporting it. Could you share the primfunc that triggers the error?

sbwww commented 1 year ago

Is there a quick way to identify which primfunc triggers the error? I'm not very familiar with compilation, so it would be great if you could point me to some links or docs.

Hzfengsy commented 1 year ago

Print the func argument inside the apply function.
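
For example, a throwaway print can go at the top of apply in the installed tvm/dlight/gpu/gemv.py (a local debugging edit only; it assumes PrimFunc.script() is available for pretty-printing the TIR):

    # /opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/dlight/gpu/gemv.py
    def apply(self, func, target, _):
        if not isinstance(func, tir.PrimFunc):
            return None
        # Temporary: dump the TIR of the function this rule is about to schedule,
        # so the offending primfunc is printed right before the crash.
        print(func.script())
        sch = tir.Schedule(func)
        ...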

sbwww commented 1 year ago

I encountered this issue with RedPajama when running:

python3 -m mlc_llm.build --hf-path togethercomputer/RedPajama-INCITE-Chat-3B-v1 --target android --quantization q4f16_0 --use-cache 0

The output, with the printed primfunc, is as follows:

Weights exist at dist/models/RedPajama-INCITE-Chat-3B-v1, skipping download.
Using path "dist/models/RedPajama-INCITE-Chat-3B-v1" for model "RedPajama-INCITE-Chat-3B-v1"
Database paths: ['log_db/redpajama-3b-q4f16', 'log_db/rwkv-raven-7b', 'log_db/redpajama-3b-q4f32', 'log_db/dolly-v2-3b', 'log_db/rwkv-raven-1b5', 'log_db/vicuna-v1-7b', 'log_db/rwkv-raven-3b']
Target configured: opencl -keys=opencl,gpu -max_num_threads=256 -max_shared_memory_per_block=16384 -max_threads_per_block=256 -texture_spatial_limit=16384 -thread_warp_size=1
Automatically using target for weight quantization: cuda -keys=cuda,gpu -arch=sm_80 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32
Start computing and quantizing weights... This may take a while.
Finish computing and quantizing weights.
Total param size: 1.4645195007324219 GB
Start storing to cache dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/params
[0710/0710] saving param_709
All finished, 51 total shards committed, record saved to dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/params/ndarray-cache.json
Finish exporting chat config to dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/params/mlc-chat-config.json
[07:24:45] /workspace/tvm/include/tvm/topi/transform.h:1076: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
[07:24:46] /workspace/tvm/include/tvm/topi/transform.h:1076: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
Save a cached module to dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/mod_cache_before_build.pkl.
# from tvm.script import tir as T

@T.prim_func
def main(lv29: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv30: T.Buffer((T.int64(80), T.int64(2560)), "float16"), p_lv49: T.handle, linear_bias3: T.Buffer((T.int64(2560),), "float16"), p_lv2: T.handle, p_output0: T.handle):
    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
    n = T.int64()
    lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16")
    lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16")
    p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16")
    with T.block("root"):
        T.reads()
        T.writes()
        T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32})
        decode_local = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16", scope="local")
        lv29_local = T.alloc_buffer((T.int64(320), T.int64(2560)), "uint32", scope="local")
        lv30_local = T.alloc_buffer((T.int64(80), T.int64(2560)), "float16", scope="local")
        lv49_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), "float16", scope="local")
        var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), "float16", scope="local")
        for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"):
            for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"):
                for i0_i1_fused_1_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
                    for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                        for i0_i1_fused_1_2_init in range(T.int64(4)):
                            for i2_2_init in T.vectorized(T.int64(8)):
                                with T.block("NT_matmul_init"):
                                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
                                    v_i1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + i0_i1_fused_1_2_init)
                                    v_i2 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + i2_2_init)
                                    T.reads()
                                    T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
                                    var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float16(0)
                        for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)):
                            for ax0 in range(T.int64(1)):
                                for ax1 in T.vectorized(T.int64(8)):
                                    with T.block("lv30_local"):
                                        v0 = T.axis.spatial(T.int64(80), k_0_0 * T.int64(4) + k_0_1 + ax0)
                                        v1 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax1)
                                        T.reads(lv30[v0, v1])
                                        T.writes(lv30_local[v0, v1])
                                        lv30_local[v0, v1] = lv30[v0, v1]
                            for k_1 in range(T.int64(4)):
                                for ax0 in range(T.int64(1)):
                                    for ax1 in T.vectorized(T.int64(8)):
                                        with T.block("lv29_local"):
                                            v0 = T.axis.spatial(T.int64(320), k_0_0 * T.int64(16) + k_0_1 * T.int64(4) + k_1 + ax0)
                                            v1 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax1)
                                            T.reads(lv29[v0, v1])
                                            T.writes(lv29_local[v0, v1])
                                            lv29_local[v0, v1] = lv29[v0, v1]
                                for k_2 in range(T.int64(8)):
                                    for ax0 in range(T.int64(1)):
                                        for ax1 in T.vectorized(T.int64(8)):
                                            with T.block("decode"):
                                                v_i = T.axis.spatial(T.int64(2560), k_0_0 * T.int64(128) + k_0_1 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax0)      
                                                v_j = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax1)
                                                T.reads(lv29_local[v_i // T.int64(8), v_j], lv30_local[v_i // T.int64(32), v_j])
                                                T.writes(decode_local[v_i, v_j])
                                                decode_local[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv29_local[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv30_local[v_i // T.int64(32), v_j]
                                    for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
                                        for ax2 in T.vectorized(T.int64(1)):
                                            with T.block("lv49_pad_local"):
                                                v0 = T.axis.spatial(T.int64(1), ax0)
                                                v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + ax1)
                                                v2 = T.axis.spatial(T.int64(2560), k_0_0 * T.int64(128) + k_0_1 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2)       
                                                T.reads(lv49[v0, v1, v2])
                                                T.writes(lv49_pad_local[v0, v1, v2])
                                                lv49_pad_local[v0, v1, v2] = T.if_then_else(v1 < n, lv49[v0, v1, v2], T.float16(0))
                                    for i0_i1_fused_1_2 in range(T.int64(4)):
                                        for i2_2 in T.vectorized(T.int64(8)):
                                            with T.block("NT_matmul_update"):
                                                v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                v_i1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + i0_i1_fused_1_2)
                                                v_i2 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + i2_2)
                                                v_k = T.axis.reduce(T.int64(2560), k_0_0 * T.int64(128) + k_0_1 * T.int64(32) + k_1 * T.int64(8) + k_2)
                                                T.reads(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv49_pad_local[v_i0, v_i1, v_k], decode_local[v_k, v_i2])
                                                T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
                                                var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv49_pad_local[v_i0, v_i1, v_k] * decode_local[v_k, v_i2]
                        for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
                            for ax2 in T.vectorized(T.int64(8)):
                                with T.block("var_NT_matmul_intermediate_pad_local"):
                                    v0 = T.axis.spatial(T.int64(1), ax0)
                                    v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + ax1)
                                    v2 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2)
                                    T.reads(var_NT_matmul_intermediate_pad_local[v0, v1, v2], linear_bias3[v2], lv2[v0, v1, v2])
                                    T.writes(p_output0_intermediate[v0, v1, v2])
                                    if v1 < n:
                                        p_output0_intermediate[v0, v1, v2] = var_NT_matmul_intermediate_pad_local[v0, v1, v2] + linear_bias3[v2] + lv2[v0, v1, v2]  
Traceback (most recent call last):
  ...
  (identical to the traceback in the original report above)
  ...
TypeError: object of type 'NoneType' has no len()

However, I cannot reproduce this when compiling Llama or RWKV. I am using mlc-ai-nightly-cu116==0.12.dev1310 and mlc-chat-nightly-cu116==0.1.dev293. It seems the GEMV scheduling logic was changed in this version, probably related to #599, or maybe relax_model/gpt_neox.py needs a refactor.
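
If it helps triage, the failure can presumably be reproduced outside of mlc_llm.build by applying just the dlight GEMV rule to the printed primfunc. A sketch, assuming the dump above is saved as an IRModule named Module in a file repro_module.py (both names are hypothetical), and using an OpenCL target to mirror the Android build:

    import tvm
    from tvm import dlight as dl
    from repro_module import Module  # hypothetical: the printed primfunc wrapped in an IRModule

    # Apply only the GEMV default-schedule rule, as mlc_llm/core.py does;
    # this should hit the same TypeError.
    with tvm.target.Target("opencl"):
        dl.ApplyDefaultSchedule(dl.gpu.GEMV())(Module)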