mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] Compilation Failure on Model orca_mini_3b #1105

Closed · aadarsh-ram closed this issue 9 months ago

aadarsh-ram commented 10 months ago

🐛 Bug

I set up mlc-llm and tried to compile Orca-Mini 3B. I am getting the following error:

(myenv) aadarsh@AAD-HPLAP:~/src/mlc-llm$ python3 -m mlc_llm.build --hf-path pankajmathur/orca_mini_3b --target vulkan --quantization q4f16_1
Weights exist at dist/models/orca_mini_3b, skipping download.
Using path "dist/models/orca_mini_3b" for model "orca_mini_3b"
Target configured: vulkan -keys=vulkan,gpu -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=256 -supports_16bit_buffer=1 -supports_8bit_buffer=1 -supports_float16=1 -supports_float32=1 -supports_int16=1 -supports_int32=1 -supports_int8=1 -supports_storage_buffer_storage_class=1 -thread_warp_size=1
Failed to detect local GPU, falling back to CPU as a target
Automatically using target for weight quantization: llvm -keys=cpu
Start computing and quantizing weights... This may take a while.
Get old param:   0%|                                                               | 0/161 [00:00<?, ?tensors/s]
Set new param:   0%|                                                               | 0/267 [00:00<?, ?tensors/s]
Get old param:   1%|▍                                                              | 1/161 [00:00<01:29,  1.78tensors/s]
Set new param:   0%|▏                                                              | 1/267 [00:00<02:28,  1.79tensors/s]
Get old param:  98%|███████████████████████████████████████████████████████████▊ | 158/161 [03:30<00:04,  1.42s/tensors]
Set new param: 100%|█████████████████████████████████████████████████████████████▊| 266/267 [03:35<00:00,  1.07tensors/s]
Finish computing and quantizing weights.
Total param size: 1.7960131168365479 GB
Start storing to cache dist/orca_mini_3b-q4f16_1/params
[0267/0267] saving param_266
All finished, 67 total shards committed, record saved to dist/orca_mini_3b-q4f16_1/params/ndarray-cache.json
Finish exporting chat config to dist/orca_mini_3b-q4f16_1/params/mlc-chat-config.json
[16:32:56] /home/aadarsh/src/tvm-unity/include/tvm/topi/transform.h:1126: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
[16:32:56] /home/aadarsh/src/tvm-unity/include/tvm/topi/transform.h:1126: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
Save a cached module to dist/orca_mini_3b-q4f16_1/mod_cache_before_build.pkl.
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/aadarsh/src/mlc-llm/mlc_llm/build.py", line 46, in <module>
    main()
  File "/home/aadarsh/src/mlc-llm/mlc_llm/build.py", line 42, in main
    core.build_model_from_args(parsed_args)
  File "/home/aadarsh/src/mlc-llm/mlc_llm/core.py", line 686, in build_model_from_args
    build(mod, args)
  File "/home/aadarsh/src/mlc-llm/mlc_llm/core.py", line 551, in build
    mod_deploy = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
  File "/home/aadarsh/src/tvm-unity/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/aadarsh/src/tvm-unity/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/aadarsh/src/tvm-unity/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
  File "/home/aadarsh/src/tvm-unity/python/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/home/aadarsh/src/tvm-unity/python/tvm/dlight/base/transform.py", line 64, in transform_module
    sch = _apply_rules(func, target, self.rules, tunable=False)
  File "/home/aadarsh/src/tvm-unity/python/tvm/dlight/base/transform.py", line 80, in _apply_rules
    space = rule.apply(func, target, tunable)
  File "/home/aadarsh/src/tvm-unity/python/tvm/dlight/gpu/reduction.py", line 102, in apply
    self._sch_inner_spatial(sch, target, block, c_factor, epilogue)
  File "/home/aadarsh/src/tvm-unity/python/tvm/dlight/gpu/reduction.py", line 225, in _sch_inner_spatial
    sch.bind(s, "threadIdx.x")
  File "/home/aadarsh/src/tvm-unity/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/aadarsh/src/tvm-unity/python/tvm/tir/schedule/schedule.py", line 1172, in bind
    _ffi_api.ScheduleBind(self, loop, thread_axis)  # type: ignore # pylint: disable=no-member
  File "/home/aadarsh/src/tvm-unity/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/aadarsh/src/tvm-unity/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
  1: tvm::tir::TracedScheduleNode::Bind(tvm::tir::LoopRV const&, tvm::runtime::String const&)
        at /home/aadarsh/src/tvm-unity/src/tir/schedule/traced_schedule.cc:311
  0: tvm::tir::ConcreteScheduleNode::Bind(tvm::tir::LoopRV const&, tvm::runtime::String const&)
        at /home/aadarsh/src/tvm-unity/src/tir/schedule/concrete_schedule.cc:562
ScheduleError: An error occurred in the schedule primitive 'bind'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(100)), "float16")
        matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")
        with T.block("root"):
            T.reads()
            T.writes()
            matmul_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16", scope="local")
            for ax0_ax1_fused_0 in T.thread_binding(T.int64(200), thread="blockIdx.x"):
                for ax0_ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                    for ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                        with T.block("matmul_rf_init"):
                            vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1)
                            v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0 * T.int64(16) + ax0_ax1_fused_1) // T.int64(100))
                            v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0 * T.int64(16) + ax0_ax1_fused_1) % T.int64(100))
                            T.reads()
                            T.writes(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
                            matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = T.float16(0)
                        for ax2_fused_0 in range((n + T.int64(15)) // T.int64(16)):
                            for u in range(1):
                                with T.block("matmul_rf_update"):
                                    vax2_fused_1 = T.axis.spatial(T.int64(16), ax2_fused_1)
                                    v0 = T.axis.spatial(T.int64(32), (ax0_ax1_fused_0 * T.int64(16) + ax0_ax1_fused_1) // T.int64(100))
                                    v1 = T.axis.spatial(T.int64(100), (ax0_ax1_fused_0 * T.int64(16) + ax0_ax1_fused_1) % T.int64(100))
                                    vax2_fused_0 = T.axis.reduce((n + T.int64(15)) // T.int64(16), ax2_fused_0)
                                    T.where(ax2_fused_0 * T.int64(16) + ax2_fused_1 < n)
                                    T.reads(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1], A[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1], B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + vax2_fused_1, v1])
                                    T.writes(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
                                    matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] = matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1] + A[T.int64(0), v0, T.int64(0), vax2_fused_0 * T.int64(16) + vax2_fused_1] * B[T.int64(0), v0, vax2_fused_0 * T.int64(16) + vax2_fused_1, v1]
                # tir.For#0
                for ax1_ax2_fused in range((ax0_ax1_fused_0 * T.int64(16) % T.int64(100) + T.int64(15)) // T.int64(100) * T.int64(100) + T.int64(100)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for ax0 in range(T.int64(16)):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        # tir.Block#1
                        with T.block("matmul"):
                        ^^^^^^^^^^^^^^^^^^^^^^^
                            vax2_fused_1 = T.axis.reduce(T.int64(16), ax0)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused_0 * T.int64(16) // T.int64(100) + ax1_ax2_fused // T.int64(100))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v1 = T.axis.spatial(T.int64(100), ax1_ax2_fused % T.int64(100))
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.reads(matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.writes(matmul[T.int64(0), v0, T.int64(0), v1])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            with T.init():
                            ^^^^^^^^^^^^^^
                                matmul[T.int64(0), v0, T.int64(0), v1] = T.float16(0)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            matmul[T.int64(0), v0, T.int64(0), v1] = matmul[T.int64(0), v0, T.int64(0), v1] + matmul_rf_local[vax2_fused_1, T.int64(0), v0, T.int64(0), v1]
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The queried subtree root tir.For#0 in SRef tree does not have compact dataflow, because its child block tir.Block#1 on SRef tree is neither a local complete block nor a local reduction block.
It violates condition #1 as a local complete block.
Definition of a local complete block:
1) All block vars are data parallel
2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
3) No overlap between the buffers the block reads and writes
It violates condition #2 as a local reduction block.
Definition of a reduction block:
1) The block has the `init` statement
2) All the block bindings are quasi-affine expressions
3) All block vars are either data parallel block vars or reduction block vars
4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree
5) The reduction block vars are not used to index the output buffers
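
For reference, here is a minimal sketch of a block that does satisfy the reduction-block conditions listed above (a hypothetical illustration, not taken from the failing module):

from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class RowSum:
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128,), "float16")):
        for i, k in T.grid(128, 128):
            with T.block("rowsum"):
                vi = T.axis.spatial(128, i)  # data-parallel block var
                vk = T.axis.reduce(128, k)   # reduction block var; never indexes the output B
                with T.init():               # the required `init` statement
                    B[vi] = T.float16(0)
                B[vi] = B[vi] + A[vi, vk]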

Please let me know what steps I should take to rectify this error.

Environment

junrushao commented 10 months ago

@vinx13 is working on a fix

junrushao commented 10 months ago

Just wanted to follow up with you, @vinx13, is this issue fixed already?

vinx13 commented 10 months ago

I'm working on a fix; we need to update the dlight rules to choose an appropriate tile size.

MasterJH5574 commented 9 months ago

Hi @aadarsh-ram, we believe the issue has already been fixed. Would you like to check out the latest TVM Unity and try again? You are welcome to reopen this issue if there are further problems.

aadarsh-ram commented 9 months ago

@MasterJH5574 I'll try the compilation process again and let you know. Thank you!

aadarsh-ram commented 9 months ago

Hey @MasterJH5574, I'm now getting the following error. I built TVM Unity as per the instructions here.

(myenv) aadarsh@AAD-HPLAP:~/src/mlc-llm$ python3 -m mlc_llm.build --hf-path pankajmathur/orca_mini_3b --target vulkan --quantization q4f16_1
Weights exist at dist/models/orca_mini_3b, skipping download.
Using path "dist/models/orca_mini_3b" for model "orca_mini_3b"
Target configured: vulkan -keys=vulkan,gpu -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=256 -supports_16bit_buffer=1 -supports_8bit_buffer=1 -supports_float16=1 -supports_float32=1 -supports_int16=1 -supports_int32=1 -supports_int8=1 -supports_storage_buffer_storage_class=1 -thread_warp_size=1
Failed to detect local GPU, falling back to CPU as a target
Automatically using target for weight quantization: llvm -keys=cpu
Start computing and quantizing weights... This may take a while.
Get old param:   0%|                                                               | 0/161 [00:00<?, ?tensors/s]
Set new param:   0%|                                                               | 0/267 [00:00<?, ?tensors/s]
Get old param:   1%|▍                                                              | 1/161 [00:00<02:16,  1.17tensors/s]
Set new param:   0%|▏                                                              | 1/267 [00:00<03:46,  1.17tensors/s]
Get old param:  98%|███████████████████████████████████████████████████████████▊ | 158/161 [03:47<00:04,  1.55s/tensors]
Set new param: 100%|█████████████████████████████████████████████████████████████▊| 266/267 [03:52<00:00,  1.01tensors/s]
Finish computing and quantizing weights.
Total param size: 1.7960131168365479 GB
Start storing to cache dist/orca_mini_3b-q4f16_1/params
[0267/0267] saving param_266
All finished, 67 total shards committed, record saved to dist/orca_mini_3b-q4f16_1/params/ndarray-cache.json
Finish exporting chat config to dist/orca_mini_3b-q4f16_1/params/mlc-chat-config.json
Save a cached module to dist/orca_mini_3b-q4f16_1/mod_cache_before_build.pkl.
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/aadarsh/src/mlc-llm/mlc_llm/build.py", line 46, in <module>
    main()
  File "/home/aadarsh/src/mlc-llm/mlc_llm/build.py", line 42, in main
    core.build_model_from_args(parsed_args)
  File "/home/aadarsh/src/mlc-llm/mlc_llm/core.py", line 686, in build_model_from_args
    build(mod, args)
  File "/home/aadarsh/src/mlc-llm/mlc_llm/core.py", line 578, in build
    ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib)
  File "/home/aadarsh/src/tvm-unity/python/tvm/relax/vm_build.py", line 343, in build
    return _vmlink(builder, target, tir_mod, ext_libs, params, system_lib=system_lib)
  File "/home/aadarsh/src/tvm-unity/python/tvm/relax/vm_build.py", line 242, in _vmlink
    lib = tvm.build(
  File "/home/aadarsh/src/tvm-unity/python/tvm/driver/build_module.py", line 281, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "/home/aadarsh/src/tvm-unity/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/aadarsh/src/tvm-unity/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/home/aadarsh/src/tvm-unity/src/driver/driver_api.cc", line 527, in operator()
    return TIRToRuntime(inputs_arg, host_target);
  File "/home/aadarsh/src/tvm-unity/src/driver/driver_api.cc", line 510, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
    device_modules.push_back(codegen::Build(device_mod, it.first));
  File "/home/aadarsh/src/tvm-unity/src/target/codegen.cc", line 72, in tvm::codegen::Build(tvm::IRModule, tvm::Target)
    ICHECK(bf != nullptr) << build_f_name << " is not enabled";
tvm.error.InternalError: Traceback (most recent call last):
  2: operator()
        at /home/aadarsh/src/tvm-unity/src/driver/driver_api.cc:527
  1: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
        at /home/aadarsh/src/tvm-unity/src/driver/driver_api.cc:510
  0: tvm::codegen::Build(tvm::IRModule, tvm::Target)
        at /home/aadarsh/src/tvm-unity/src/target/codegen.cc:72
  File "/home/aadarsh/src/tvm-unity/src/target/codegen.cc", line 72
InternalError: Check failed: (bf != nullptr) is false: target.build.vulkan is not enabled

MasterJH5574 commented 9 months ago

Hi @aadarsh-ram, you mentioned you built TVM Unity from source. Could you check and make sure set(USE_VULKAN ON) is set in build/config.cmake? Following the instructions at https://llm.mlc.ai/docs/install/tvm.html#option-2-build-from-source, the commands would be:

# Now at path/to/tvm
cd build
cp ../cmake/config.cmake .

# controls default compilation flags
echo "set(CMAKE_BUILD_TYPE RelWithDebInfo)" >> config.cmake
# LLVM is a must dependency
echo "set(USE_LLVM \"llvm-config --ignore-libllvm --link-static\")" >> config.cmake
echo "set(HIDE_PRIVATE_SYMBOLS ON)" >> config.cmake
echo "set(USE_VULKAN ON)" >> config.cmake

After finishing the CMake configuration, we can build TVM again; Vulkan is then expected to be enabled.
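
For completeness, a sketch of the rebuild step (assuming the standard CMake workflow from the TVM Unity docs; adjust the parallelism to your machine):

# Now at path/to/tvm/build, with config.cmake updated as above
cmake ..
cmake --build . --parallel $(nproc)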

MasterJH5574 commented 9 months ago

And here is an easy check in Python to verify that Vulkan is enabled:

>>> import tvm
>>> tvm.vulkan()
vulkan(0)

If it displays vulkan(0) as output, then Vulkan is successfully enabled.
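
A slightly more direct check, if you prefer (an optional sketch using tvm.runtime.enabled and the Device.exist attribute from the standard TVM runtime API):

>>> import tvm
>>> tvm.runtime.enabled("vulkan")  # True only if TVM was built with USE_VULKAN ON
True
>>> tvm.vulkan(0).exist            # True only if a Vulkan device is actually visible
True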

aadarsh-ram commented 9 months ago

Yep, perfect! It worked. Thank you for your help!

MasterJH5574 commented 9 months ago

Glad to hear it works!