mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Question] Running TVM Dlight low-level optimizations ERROR #2661

Closed · ponytaill closed this issue 3 months ago

ponytaill commented 3 months ago

❓ General Questions

Hi, I am deploying my own quantization method in MLC-LLM, but I get errors when running the TVM Dlight low-level optimizations. Here is the `forward` code of my `QuantizedLinear`:

```python
def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name
    out_shape = x.shape[:-1] + [self.out_features]
    flattened_shape = (tir.IntImm("int64", -1), tir.IntImm("int64", x.shape[-1]))
    x = nn.op.reshape(x, shape=flattened_shape).astype("float16")

    quant_x, s1 = self.config.dynamic_quant(x)
    w = nn.op.tensor_expr_op(  # pylint: disable=invalid-name
        lambda weight: self.config._dequantize(  # pylint: disable=protected-access
            weight,
        ),
        name_hint="_dequantize",
        args=[self.B],
    )

    temp = nn.op.matmul(quant_x, w, out_dtype="int32")
    s = nn.op.matmul(s1, self.s_channel / 16, out_dtype="float32")

    out = nn.tensor_expr_op(  # pylint: disable=invalid-name
        lambda temp, s: self.config.elementwise_multiply(  # pylint: disable=protected-access
            temp,
            s
        ),
        name_hint="elementwise_multiply",
        args=[temp, s],
    ).astype("float16")

    out = nn.op.reshape(out, shape=out_shape)
    out = out + self.bias if self.bias is not None else out
    return out
```
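For reference, the tensor shapes involved are roughly as follows (inferred from the code above; `m` is the number of tokens, `k = in_features`, `n = out_features`):

```python
# Inferred shapes, for reference only:
# quant_x: [m, k]          (dynamically quantized activations)
# w:       [k, n]          (dequantized weight)
# temp:    [m, n] int32    (matmul accumulator, out_dtype="int32")
# s1:      [m, 1] float32  (per-token activation scale)
# s:       [m, n] float32  (s1 @ (s_channel / 16), the combined scale)
```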

The error happens in this `tensor_expr_op` call:

```python
out = nn.tensor_expr_op(  # pylint: disable=invalid-name
    lambda temp, s: self.config.elementwise_multiply(  # pylint: disable=protected-access
        temp,
        s
    ),
    name_hint="elementwise_multiply",
    args=[temp, s],
).astype("float16")
```

And `self.config.elementwise_multiply` is the following function:

```python
def elementwise_multiply(self, A: te.Tensor, B: te.Tensor) -> te.Tensor:
    return te.compute(
        A.shape,
        lambda i, j: A[i, j] * B[i, j],  # elementwise multiplication
        name="result",
    )
```
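Note that this compute is purely spatial, with no reduction axis. A minimal standalone snippet (hypothetical shapes, dtypes simplified) shows the kind of TIR loop nest it produces, which is what the Dlight rules later pattern-match against:

```python
import tvm
from tvm import te

# Hypothetical shapes; only the loop pattern matters here.
A = te.placeholder((128, 4096), dtype="float32", name="A")
B = te.placeholder((128, 4096), dtype="float32", name="B")
C = te.compute(A.shape, lambda i, j: A[i, j] * B[i, j], name="result")

# Prints two spatial loops and no reduction loop.
print(te.create_prim_func([A, B, C]).script())
```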

When I compile, I get the following error:

```
[2024-07-15 17:48:22] INFO pipeline.py:52: Running TVM Relax graph-level optimizations
[2024-07-15 17:50:15] INFO pipeline.py:52: Lowering to TVM TIR kernels
[2024-07-15 17:50:35] INFO pipeline.py:52: Running TVM TIR-level optimizations
[2024-07-15 17:51:08] INFO pipeline.py:52: Running TVM Dlight low-level optimizations
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/__main__.py", line 64, in <module>
    main()
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/__main__.py", line 33, in main
    cli.main(sys.argv[2:])
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/cli/compile.py", line 129, in main
    compile(
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 229, in compile
    _compile(args, model_config)
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 174, in _compile
    args.build_func(
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/support/auto_target.py", line 284, in build
    relax.build(
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/compiler_pass/pipeline.py", line 182, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/dlight/base/transform.py", line 71, in transform_module
    sch = _apply_rules(func, target, self.rules, tunable=False)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/dlight/base/transform.py", line 87, in _apply_rules
    space = rule.apply(func, target, tunable)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/dlight/gpu/reduction.py", line 88, in apply
    is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize(
                                                              ^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/tvm/dlight/gpu/reduction.py", line 162, in _normalize
    assert r_loops
           ^^^^^^^
AssertionError
Traceback (most recent call last):
  File "/root/autodl-tmp/mlc-llm/1.py", line 5, in <module>
    engine = MLCEngine(model)
             ^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/serve/engine.py", line 1477, in __init__
    super().__init__(
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/serve/engine_base.py", line 589, in __init__
    ) = _process_model_args(models, device, engine_config)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/serve/engine_base.py", line 170, in _process_model_args
    model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models]
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/serve/engine_base.py", line 170, in <listcomp>
    model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models]
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/serve/engine_base.py", line 163, in _convert_model_info
    model_lib = jit.jit(
                ^^^^^^^^
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/interface/jit.py", line 164, in jit
    _run_jit(
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/interface/jit.py", line 124, in _run_jit
    raise RuntimeError("Cannot find compilation output, compilation failed")
RuntimeError: Cannot find compilation output, compilation failed
Exception ignored in: <function MLCEngineBase.__del__ at 0x7f01964c2a20>
Traceback (most recent call last):
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/serve/engine_base.py", line 653, in __del__
    self.terminate()
  File "/root/miniconda3/envs/mlc/lib/python3.11/site-packages/mlc_llm/serve/engine_base.py", line 660, in terminate
    self._ffi["exit_background_loop"]()
    ^^^^^^^^^
AttributeError: 'MLCEngine' object has no attribute '_ffi'
```
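So the assertion fires in Dlight's GPU `Reduction` rule: its `_normalize` step collects no reduction loops for this kernel (`assert r_loops` fails), even though the kernel is purely elementwise. One way to reproduce just the scheduling step in isolation is to run the Dlight rules on the lowered module by hand. A hedged sketch (rule list approximates what the MLC pipeline applies; `mod` is a placeholder for your lowered `IRModule`):

```python
import tvm
from tvm import dlight as dl

def apply_dlight_gpu(mod: tvm.IRModule) -> tvm.IRModule:
    # Apply the Dlight GPU rules manually to see which rule's `apply`
    # raises on the offending PrimFunc.
    with tvm.target.Target("cuda"):
        return dl.ApplyDefaultSchedule(
            dl.gpu.Matmul(),
            dl.gpu.GEMV(),
            dl.gpu.Reduction(),
            dl.gpu.GeneralReduction(),
            dl.gpu.Fallback(),
        )(mod)
```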

By the way, I also tried calculating the result in two separate steps, like this:

```python
def elementwise_multiply(self, A: te.Tensor, s1: te.Tensor, s2: te.Tensor) -> te.Tensor:
    out = te.compute(A.shape, lambda i, j: A[i, j] * s1[i, 0], name="out")
    return te.compute(out.shape, lambda i, j: out[i, j] * s2[0, j], name="result")

This time I did not get an error, but the model's temporary buffer suddenly became very large, which is strange. How can I solve this?
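Would it also make sense to avoid the custom TE kernel entirely and express the scaling with broadcasted Relax ops? Something like this sketch, which is not my current code and reuses the names from the `forward` above (assuming `s1` is `[m, 1]` float32 and `self.s_channel` reshapes to a `[1, n]` per-channel scale):

```python
# Hedged sketch: let Relax broadcast the two scales instead of using a
# custom te.compute, so Dlight only sees standard elementwise ops.
s2 = nn.op.reshape(self.s_channel / 16, shape=(1, self.out_features))
out = (temp.astype("float32") * s1 * s2).astype("float16")
```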

ponytaill commented 3 months ago

Strangely enough, it works when the result is calculated in two separate steps.
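If anyone hits the same temporary-buffer blow-up with the two-step version, one possible refinement (an untested sketch, same math as above) is to fuse both scales into a single `te.compute` stage, so the intermediate `out` buffer between the two scalings is never materialized:

```python
from tvm import te

def elementwise_multiply(self, A: te.Tensor, s1: te.Tensor, s2: te.Tensor) -> te.Tensor:
    # Fused variant: both per-token and per-channel scalings happen in one
    # stage, avoiding the intermediate "out" buffer of the two-stage version.
    return te.compute(
        A.shape,
        lambda i, j: A[i, j] * s1[i, 0] * s2[0, j],
        name="result",
    )
```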

loredunk commented 2 months ago

How do you fix it?