apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Relax][Bug] Cannot find PackedFunc tir_zeros #17176

Closed Cookiee235 closed 1 day ago

Cookiee235 commented 1 month ago

Actual behavior

Traceback (most recent call last):
  File "demo.py", line 25, in <module>
    vm = relax.VirtualMachine(ex, tvm.cpu())
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/runtime/relax_vm.py", line 97, in __init__
    self._setup_device(device, memory_cfg)
  File "/software/tvm/python/tvm/runtime/relax_vm.py", line 133, in _setup_device
    self.module["vm_initialization"](*init_args)
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  3: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_8relax_vm18VirtualMachineImpl
  2: tvm::runtime::relax_vm::VirtualMachineImpl::_Init(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::runtime::relax_vm::VirtualMachineImpl::Init(std::vector<DLDevice, std::allocator<DLDevice> > const&, std::vector<tvm::runtime::memory::AllocatorType, std::allocator<tvm::runtime::memory::AllocatorType> > const&)
  0: tvm::runtime::relax_vm::VirtualMachineImpl::InitFuncPool()
  File "/software/tvm/src/runtime/relax_vm/vm.cc", line 705
InternalError: Check failed: (func.defined()) is false: Error: Cannot find PackedFunc tir_zeros in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable

Environment

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"system_lib_prefix": "hello_"})
    @T.prim_func
    def tir_zeros(x: T.Buffer((2,), "float32")):
        x[0] = T.float32(0)

    @R.function
    def main() -> R.Tensor((2,), dtype="float32"):
        cls = Module
        gv0 = R.call_tir(cls.tir_zeros, R.tuple(), out_sinfo=R.Tensor((2,), dtype="float32"))
        return gv0

m = Module
m.show()
ex = relax.build(m, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())

cc @junrushao @Hzfengsy

Cookiee235 commented 1 month ago

@tqchen @Hzfengsy @junrushao The function tir_zeros is defined explicitly in the Module. Why does compilation crash and throw: Cannot find PackedFunc tir_zeros? I'm very confused. Looking forward to your response. Many thanks!

Lunderberg commented 1 month ago

I was able to reproduce the error, and it looks like it's an incompatibility between the AttachGlobalSymbols pass and the relax.VMCodeGen step. When the "system_lib_prefix" is included, the "global_symbol" attribute produced by AttachGlobalSymbols doesn't match the name of the GlobalVar. As a result, the compiled TIR module provides a function named "hello_tir_zeros", but the Relax VM attempts to find a function named "tir_zeros".

This would be caught by the well-formed checker, which requires the "global_symbol" to exactly match the GlobalVar value if present. However, this is only validated for Relax functions, not for TIR functions (implementation here).
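The mismatch can be illustrated with a small plain-Python mock. These dictionaries are stand-ins for TVM's IRModule and kernel library, not real TVM APIs; the `attach_global_symbols` and `vm_lookup` helpers are hypothetical names that only mimic the behavior described above:

```python
def attach_global_symbols(functions, prefix):
    """Mimic the pass: prefix each function's exported "global_symbol",
    but (the buggy behavior) leave the GlobalVar key unchanged."""
    return {
        gvar_name: {**func, "global_symbol": prefix + gvar_name}
        for gvar_name, func in functions.items()
    }

def vm_lookup(kernel_library, gvar_name):
    """Mimic InitFuncPool: the VM searches the kernel library by GlobalVar name."""
    if gvar_name not in kernel_library:
        raise KeyError(f"Cannot find PackedFunc {gvar_name}")
    return kernel_library[gvar_name]

module = {"tir_zeros": {"body": "x[0] = 0.0"}}
compiled = attach_global_symbols(module, prefix="hello_")

# The compiled kernel library is keyed by the exported symbol name...
kernel_library = {f["global_symbol"]: f for f in compiled.values()}

# ...so looking it up by the original GlobalVar name fails:
try:
    vm_lookup(kernel_library, "tir_zeros")
except KeyError as err:
    print(err)  # reports the missing-function error, as in the issue
```

The compiled library only knows "hello_tir_zeros", while the VM asks for "tir_zeros", which matches the failure in the original report.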

Lunderberg commented 1 month ago

To resolve this, and to prevent similar bugs from occurring in the future, I think we should make the following changes:

1) When AttachGlobalSymbol changes the name of a function, it must also replace the GlobalVar used to refer to that function. All occurrences of the old GlobalVar across the IRModule must then be replaced by the new GlobalVar.
2) The well-formed checker should validate that the GlobalVar and "global_symbol" match across all functions, not just Relax functions.
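Both proposals can be sketched in plain Python over a mock module table (not the actual TVM passes; `rename_function` and `check_well_formed` are illustrative names):

```python
def rename_function(module, old_name, new_name):
    """Fix 1: rename a function AND replace every reference to the old GlobalVar."""
    func = module.pop(old_name)
    func["global_symbol"] = new_name
    module[new_name] = func
    # Rewrite all call sites across the module to use the new name.
    for f in module.values():
        f["calls"] = [new_name if c == old_name else c for c in f.get("calls", [])]
    return module

def check_well_formed(module):
    """Fix 2: GlobalVar name and "global_symbol" must match for every function,
    TIR functions included, not only Relax functions."""
    for gvar_name, func in module.items():
        sym = func.get("global_symbol")
        if sym is not None and sym != gvar_name:
            raise ValueError(f"global_symbol {sym!r} does not match GlobalVar {gvar_name!r}")

module = {
    "tir_zeros": {"global_symbol": "tir_zeros", "calls": []},
    "main": {"global_symbol": "main", "calls": ["tir_zeros"]},
}
rename_function(module, "tir_zeros", "hello_tir_zeros")
check_well_formed(module)  # passes: the name and symbol stay in sync
```

With both pieces in place, a prefix rename can no longer leave a dangling reference behind, and if it somehow did, the checker would reject the module before codegen.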

Lunderberg commented 1 month ago

@Cookiee235 Can you test with #17202 applied? With this PR, I am able to successfully run your test case.

Cookiee235 commented 1 month ago

@Cookiee235 Can you test with #17202 applied? With this PR, I am able to successfully run your test case.

Hi @Lunderberg Thanks for your efforts. Following your advice, I pulled the source code from your repository (i.e., Lunderberg:replace_global_var_on_rename) and built TVM from source. However, the build failed. Part of the crash message is as follows:

In file included from /software/tvm-lunder/src/runtime/builtin_fp16.cc:25:
/software/tvm-lunder/include/tvm/runtime/c_runtime_api.h:144:15: error: 'kDLROCMHost' was not declared in this scope; did you mean 'kDLROCM'?
  144 | static_assert(kDLROCMHost == 11, TVM_HARCODED_INTEGER_CHANGED_MSG);
      |               ^~~~~~~~~~~
      |               kDLROCM

...

make[2]: *** [CMakeFiles/tvm_libinfo_objs.dir/build.make:76: CMakeFiles/tvm_libinfo_objs.dir/src/support/libinfo.cc.o] Error 1
make[1]: *** [CMakeFiles/Makefile2:905: CMakeFiles/tvm_libinfo_objs.dir/all] Error 2

Lunderberg commented 1 month ago

Hmm. That looks like it isn't able to find any symbols from dlpack.h. Did you checkout the submodules? (git submodule update --init --recursive)

Cookiee235 commented 1 month ago

Hmm. That looks like it isn't able to find any symbols from dlpack.h. Did you checkout the submodules? (git submodule update --init --recursive)

So sorry, I forgot. Indeed, I also successfully ran this test under the PR https://github.com/apache/tvm/pull/17202 applied! @Lunderberg Thank you again for your PR!

jpf888 commented 1 month ago

@Lunderberg @Cookiee235 Hi, when I try to use the cudnn dispatch for conv2d, I encounter the same problem at runtime. After applying the code changes you submitted, I still hit the same error. Perhaps our issues are similar but not exactly the same. Could you please give me some suggestions?

1. Actual behavior:

  ...vm::runtime::Optional<tvm::runtime::Session> const&, bool&>(tvm::runtime::String&, tvm::runtime::String&, picojson::object_with_ordered_keys const&, DLDevice&, tvm::runtime::Optional<tvm::runtime::Session> const&, bool&)
        at mlc-llm/3rdparty/tvm/include/tvm/runtime/memory.h:196
  5: tvm::runtime::ObjectPtr<mlc::llm::serve::ModelImpl> tvm::runtime::ObjAllocatorBase<tvm::runtime::SimpleObjAllocator>::make_object<mlc::llm::serve::ModelImpl, tvm::runtime::String&, tvm::runtime::String&, picojson::object_with_ordered_keys const&, DLDevice&, tvm::runtime::Optional<tvm::runtime::Session> const&, bool&>(tvm::runtime::String&, tvm::runtime::String&, picojson::object_with_ordered_keys const&, DLDevice&, tvm::runtime::Optional<tvm::runtime::Session> const&, bool&)
        at mlc-llm/3rdparty/tvm/include/tvm/runtime/memory.h:72
  4: mlc::llm::serve::ModelImpl* tvm::runtime::SimpleObjAllocator::Handler<mlc::llm::serve::ModelImpl>::New<tvm::runtime::String&, tvm::runtime::String&, picojson::object_with_ordered_keys const&, DLDevice&, tvm::runtime::Optional<tvm::runtime::Session> const&, bool&>(tvm::runtime::SimpleObjAllocator*, tvm::runtime::String&, tvm::runtime::String&, picojson::object_with_ordered_keys const&, DLDevice&, tvm::runtime::Optional<tvm::runtime::Session> const&, bool&)
        at mlc-llm/3rdparty/tvm/include/tvm/runtime/memory.h:122
  3: mlc::llm::serve::ModelImpl::ModelImpl(tvm::runtime::String, tvm::runtime::String, picojson::object_with_ordered_keys, DLDevice, tvm::runtime::Optional<tvm::runtime::Session> const&, bool)
        at /mlc-llm/cpp/serve/model.cc:66
  2: mlc::llm::serve::FunctionTable::Init(tvm::runtime::String, DLDevice, picojson::object_with_ordered_keys, tvm::runtime::Optional<tvm::runtime::Session>)
        at /mlc-llm/cpp/serve/function_table.cc:133
  1: tvm::runtime::relax_vm::VirtualMachineImpl::_Init(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::runtime::relax_vm::VirtualMachineImpl::InitFuncPool()
  File "/mlc-llm/3rdparty/tvm/src/runtime/relax_vm/vm.cc", line 707
InternalError: Check failed: (func.defined()) is false: Error: Cannot find PackedFunc fused_relax_nn_conv2d_cudnn in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable

2. But when I run TVM's cudnn conv2d test case separately, it works normally, and I can find 'fused_relax_nn_conv2d_cudnn'.

3. I printed out the mod before and after the cudnn conv2d pattern processing, and the names are all the same as in the TVM cudnn test, such as fused_relax_nn_conv2d_cudnn.

Lunderberg commented 1 month ago

@jpf888 This sounds like a similar bug, but would depend on how the dispatch is implemented. When the pattern-matching is applied, is the call to the fused kernel generated as module.fused_function(args), or as R.call_extern("fused_function", args)? The fix applied in #17202 would only apply to calls that are known to be within the same IRModule (the module.fused_function(args) version), and not to calls that may be defined outside of the IRModule (the R.call_extern version).
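The distinction can be modeled in plain Python (this is not TVM's actual implementation; `GlobalVar` here is a toy class, and the prefixed name is hypothetical). A GlobalVar call holds an object reference, so a rename pass can rewrite it; an extern call holds a bare string resolved at runtime, which a rename cannot see:

```python
class GlobalVar:
    """Toy stand-in for a TVM GlobalVar: a mutable named reference."""
    def __init__(self, name):
        self.name = name

fused = GlobalVar("fused_relax_nn_conv2d_cudnn")

# Two ways the pattern-matcher could emit the call site:
intra_module_call = {"kind": "module_call", "target": fused}  # cls.fused_...(args)
extern_call = {"kind": "call_extern", "target": fused.name}   # R.call_extern("fused_...", args)

# A rename pass updates the GlobalVar object itself...
fused.name = "prefixed_fused_relax_nn_conv2d_cudnn"

# ...so the intra-module call follows the rename, while the extern
# call's string is frozen at the old value and the lookup will miss:
print(intra_module_call["target"].name)  # prefixed_fused_relax_nn_conv2d_cudnn
print(extern_call["target"])             # fused_relax_nn_conv2d_cudnn
```

This is why the #17202 fix repairs calls known to be within the IRModule but cannot help call sites that only carry the callee's name as a string.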

jpf888 commented 1 month ago

@Lunderberg 1. When I apply dispatch in mlc-llm, this issue occurs. When the pattern-matching is applied, the call to the fused kernel is generated as R.call_dps_packed("fused_relax_nn_conv2d_cudnn", args) inside the class Module.

2. In the TVM test case it works fine. Log before the pattern is applied:

from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight: R.Tensor((32, 3, 3, 16), dtype="float16")) -> R.Tensor((16, 32, 32, 32), dtype="float16"):
        with R.dataflow():
            lv: R.Tensor((16, 32, 32, 32), dtype="float16") = R.nn.conv2d(data, weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC", out_dtype="float16")
            R.output(lv)
        return lv

After the pattern is applied:

    from tvm.script import ir as I
    from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def fused_relax_nn_conv2d_cudnn(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight: R.Tensor((32, 3, 3, 16), dtype="float16")) -> R.Tensor((16, 32, 32, 32), dtype="float16"):
            R.func_attr({"Codegen": "cudnn"})
            # from tvm.script import relax as R

            @R.function
            def local_func(data_1: R.Tensor((16, 32, 32, 16), dtype="float16"), weight_1: R.Tensor((32, 3, 3, 16), dtype="float16")) -> R.Tensor((16, 32, 32, 32), dtype="float16"):
                R.func_attr({"Composite": "cudnn.conv2d.nhwc_ohwi"})
                with R.dataflow():
                    gv: R.Tensor((16, 32, 32, 32), dtype="float16") = R.nn.conv2d(data_1, weight_1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC", out_dtype="float16")
                    R.output(gv)
                return gv

            output: R.Tensor((16, 32, 32, 32), dtype="float16") = local_func(data, weight)
            return output

        @R.function
        def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight: R.Tensor((32, 3, 3, 16), dtype="float16")) -> R.Tensor((16, 32, 32, 32), dtype="float16"):
            cls = Module
            with R.dataflow():
                gv: R.Tensor((16, 32, 32, 32), dtype="float16") = cls.fused_relax_nn_conv2d_cudnn(data, weight)
                R.output(gv)
            return gv

jpf888 commented 1 month ago

mlc-llm cudnn dispatch

    import tvm
    import tvm.relax.backend.contrib.cudnn as _cudnn
    from tvm.relax.backend.contrib.cudnn import partition_for_cudnn
    from tvm import IRModule, relax
    from tvm.relax.backend import get_patterns_with_prefix

    @tvm.transform.module_pass(opt_level=0, name="CudnnDispatch")
    class CudnnDispatch:  # pylint: disable=too-few-public-methods,broad-exception-raised

        def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
            """IRModule-level transformation"""
            has_cudnn = tvm.get_global_func("relax.ext.cudnn", True)
            if not has_cudnn:
                raise Exception("CUDNN is not enabled.")
            patterns = get_patterns_with_prefix("cudnn")

            model_names = [
                gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function)
            ]
            # exclude single batch decode
            model_names = [name for name in model_names if "image_embed" in name]

            mod = tvm.transform.Sequential(
                [
                    relax.transform.FuseOpsByPattern(
                        patterns,
                        bind_constants=False,
                        annotate_codegen=True,
                        entry_functions=model_names,
                    ),
                    relax.transform.RunCodegen({}, entry_functions=model_names),
                ]
            )(mod)
            return mod

jpf888 commented 1 month ago

@Lunderberg I resolved the issue: it was caused by adding multiple dispatches to the mlc-llm compile pipeline. The problem consistently occurred with the first dispatch (cudnn), but with just one dispatch left in place, any dispatch, it worked!