Closed sbwww closed 1 year ago
Thanks for reporting it. Could you show the primfunc that meets the error?
Is there a quick way to identify which primfunc triggers the error? I'm not very familiar with compilation, so it would be kind if you could provide some links or docs.
Print the func arguments in the apply function.
I encounter this issue in RedPajama when running python3 -m mlc_llm.build --hf-path togethercomputer/RedPajama-INCITE-Chat-3B-v1 --target android --quantization q4f16_0 --use-cache 0
The output with printed primfunc is as follows
Weights exist at dist/models/RedPajama-INCITE-Chat-3B-v1, skipping download.
Using path "dist/models/RedPajama-INCITE-Chat-3B-v1" for model "RedPajama-INCITE-Chat-3B-v1"
Database paths: ['log_db/redpajama-3b-q4f16', 'log_db/rwkv-raven-7b', 'log_db/redpajama-3b-q4f32', 'log_db/dolly-v2-3b', 'log_db/rwkv-raven-1b5', 'log_db/vicuna-v1-7b', 'log_db/rwkv-raven-3b']
Target configured: opencl -keys=opencl,gpu -max_num_threads=256 -max_shared_memory_per_block=16384 -max_threads_per_block=256 -texture_spatial_limit=16384 -thread_warp_size=1
Automatically using target for weight quantization: cuda -keys=cuda,gpu -arch=sm_80 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32
Start computing and quantizing weights... This may take a while.
Finish computing and quantizing weights.
Total param size: 1.4645195007324219 GB
Start storing to cache dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/params
[0710/0710] saving param_709
All finished, 51 total shards committed, record saved to dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/params/ndarray-cache.json
Finish exporting chat config to dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/params/mlc-chat-config.json
[07:24:45] /workspace/tvm/include/tvm/topi/transform.h:1076: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound[07:24:46] /workspace/tvm/include/tvm/topi/transform.h:1076: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in boundSave a cached module to dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0/mod_cache_before_build.pkl.
# from tvm.script import tir as T
@T.prim_func
def main(lv29: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv30: T.Buffer((T.int64(80), T.int64(2560)), "float16"), p_lv49: T.handle, linear_bias3: T.Buffer((T.int64(2560),), "float16"), p_lv2: T.handle, p_output0: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
n = T.int64()
lv49 = T.match_buffer(p_lv49, (T.int64(1), n, T.int64(2560)), "float16")
lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(2560)), "float16")
p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16")
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.thread_extent_low_inclusive": 32})
decode_local = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16", scope="local")
lv29_local = T.alloc_buffer((T.int64(320), T.int64(2560)), "uint32", scope="local")
lv30_local = T.alloc_buffer((T.int64(80), T.int64(2560)), "float16", scope="local")
lv49_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), "float16", scope="local")
var_NT_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(2560)), "float16", scope="local") for i0_i1_fused_0_i0_i1_fused_1_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.y"):
for i2_0 in T.thread_binding(T.int64(20), thread="blockIdx.x"):
for i0_i1_fused_1_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
for i2_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
for i0_i1_fused_1_2_init in range(T.int64(4)):
for i2_2_init in T.vectorized(T.int64(8)):
with T.block("NT_matmul_init"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + i0_i1_fused_1_2_init)
v_i2 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + i2_2_init)
T.reads()
T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float16(0)
for k_0_0, k_0_1 in T.grid(T.int64(20), T.int64(4)):
for ax0 in range(T.int64(1)):
for ax1 in T.vectorized(T.int64(8)):
with T.block("lv30_local"):
v0 = T.axis.spatial(T.int64(80), k_0_0 * T.int64(4) + k_0_1 + ax0)
v1 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax1)
T.reads(lv30[v0, v1])
T.writes(lv30_local[v0, v1])
lv30_local[v0, v1] = lv30[v0, v1]
for k_1 in range(T.int64(4)):
for ax0 in range(T.int64(1)):
for ax1 in T.vectorized(T.int64(8)):
with T.block("lv29_local"):
v0 = T.axis.spatial(T.int64(320), k_0_0 * T.int64(16) + k_0_1 * T.int64(4) + k_1 + ax0)
v1 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax1)
T.reads(lv29[v0, v1])
T.writes(lv29_local[v0, v1])
lv29_local[v0, v1] = lv29[v0, v1]
for k_2 in range(T.int64(8)):
for ax0 in range(T.int64(1)):
for ax1 in T.vectorized(T.int64(8)):
with T.block("decode"):
v_i = T.axis.spatial(T.int64(2560), k_0_0 * T.int64(128) + k_0_1 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax0)
v_j = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax1)
T.reads(lv29_local[v_i // T.int64(8), v_j], lv30_local[v_i // T.int64(32), v_j])
T.writes(decode_local[v_i, v_j])
decode_local[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv29_local[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv30_local[v_i // T.int64(32), v_j]
for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
for ax2 in T.vectorized(T.int64(1)):
with T.block("lv49_pad_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(2560), k_0_0 * T.int64(128) + k_0_1 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2)
T.reads(lv49[v0, v1, v2])
T.writes(lv49_pad_local[v0, v1, v2])
lv49_pad_local[v0, v1, v2] = T.if_then_else(v1 < n, lv49[v0, v1, v2], T.float16(0))
for i0_i1_fused_1_2 in range(T.int64(4)):
for i2_2 in T.vectorized(T.int64(8)):
with T.block("NT_matmul_update"):
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
v_i1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + i0_i1_fused_1_2)
v_i2 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + i2_2)
v_k = T.axis.reduce(T.int64(2560), k_0_0 * T.int64(128) + k_0_1 * T.int64(32) + k_1 * T.int64(8) + k_2)
T.reads(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2], lv49_pad_local[v_i0, v_i1, v_k], decode_local[v_k, v_i2])
T.writes(var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + lv49_pad_local[v_i0, v_i1, v_k] * decode_local[v_k, v_i2]
for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
for ax2 in T.vectorized(T.int64(8)):
with T.block("var_NT_matmul_intermediate_pad_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0_i0_i1_fused_1_0_fused * T.int64(32) + i0_i1_fused_1_1 * T.int64(4) + ax1)
v2 = T.axis.spatial(T.int64(2560), i2_0 * T.int64(128) + i2_1 * T.int64(8) + ax2)
T.reads(var_NT_matmul_intermediate_pad_local[v0, v1, v2], linear_bias3[v2], lv2[v0, v1, v2])
T.writes(p_output0_intermediate[v0, v1, v2])
if v1 < n:
p_output0_intermediate[v0, v1, v2] = var_NT_matmul_intermediate_pad_local[v0, v1, v2] + linear_bias3[v2] + lv2[v0, v1, v2]
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/root/mlc-llm/mlc_llm/build.py", line 13, in <module>
main()
File "/root/mlc-llm/mlc_llm/build.py", line 10, in main
core.build_model_from_args(parsed_args)
File "/root/mlc-llm/mlc_llm/core.py", line 450, in build_model_from_args
build(mod, args)
File "/root/mlc-llm/mlc_llm/core.py", line 370, in build
mod_deploy = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod_deploy)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
TypeError: Traceback (most recent call last):
5: TVMFuncCall
4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
3: tvm::transform::Pass::operator()(tvm::IRModule) const
2: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
1: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/ir/transform.py", line 307, in _pass_func
return inst.transform_module(mod, ctx)
File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/dlight/base/transform.py", line 64, in transform_module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
sch = _apply_rules(func, target, self.rules, tunable=False)
File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/dlight/base/transform.py", line 80, in _apply_rules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
space = rule.apply(func, target, tunable)
File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/dlight/gpu/gemv.py", line 165, in apply
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
if len(block_infos) == 1:
^^^^^^^^^^^^^^^^
TypeError: object of type 'NoneType' has no len()
However, I cannot reproduce this when compiling Llama and RWKV. I'm using mlc-ai-nightly-cu116==0.12.dev1310 and mlc-chat-nightly-cu116==0.1.dev293. It seems that the GEMV logic was modified in this version, probably related to #599, or maybe relax_model/gpt_neox.py needs a refactor.
🐛 Bug
I encountered the following error when I was trying to add a custom relax_model and compile it.
In tvm.dlight.gpu.gemv.GEMV below, block_infos can be returned as None by the functions normalize_prim_func() and try_inline_contiguous_spatial(). Then, using len(block_infos) without checking for None results in an error. It may be something wrong on my end, but I am not clear how to locate it through the error information.

To Reproduce
Steps to reproduce the behavior:
1. 1. 1.
Expected behavior
Environment
Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA):
CUDA
Operating system (e.g. Ubuntu/Windows/MacOS/...):
CentOS
Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...):
How you installed MLC-LLM (conda, source): conda
How you installed TVM-Unity (pip, source): pip
Python version (e.g. 3.10):
3.11
GPU driver version (if applicable):
CUDA/cuDNN version (if applicable):
TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models):
Any other relevant information:
Additional context