apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] [Relax] cannot remove the hint_on_device #17205

Closed · MellowArtisan closed this issue 9 hours ago

MellowArtisan commented 1 month ago

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/7bug_assert.py", line 25, in <module>
    tvm.ir.assert_structural_equal(mod_seq, mod)  # assert failed
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/base.py", line 256, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  5: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  4: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}>(tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  3: _ZN3tvm20SEqualHandlerDefault5EqualERKNS_
  2: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
  1: tvm::SEqualHandlerDefault::Impl::RunTasks()
  0: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  File "/software/tvm/src/node/structural_equal.cc", line 392
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})
    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = y
                                                                       ^
            R.output(lv0)
        return lv0
and rhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})
    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
                                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            R.output(lv0)
        return lv0

Environment

TVM: 0.17.dev0

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})

    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
            R.output(lv0)
        return lv0

mod = Module
mod_seq = tvm.transform.Sequential([relax.transform.RealizeVDevice()])(mod)
mod = relax.transform.RealizeVDevice()(mod)
mod_seq.show()
mod.show()  # cannot remove the 'hint_on_device'
tvm.ir.assert_structural_equal(mod_seq, mod)  # assert failed

cc @junrushao

MellowArtisan commented 1 month ago

Hi all, I found that tvm.transform.Sequential([relax.transform.RealizeVDevice()])(mod) removes the R.hint_on_device() call, whereas applying relax.transform.RealizeVDevice()(mod) directly does not remove it.

@Lunderberg @tqchen Why do these two ways of invoking the same pass (i.e., RealizeVDevice) give different optimization results?

Lunderberg commented 1 month ago

Hmm, it looks like it's even weirder than that: only the first use of RealizeVDevice produces the correct output. The first execution correctly removes the R.hint_on_device, but the second execution does not.

relax.transform.RealizeVDevice()(Module).show(name='First')
relax.transform.RealizeVDevice()(Module).show(name='Second')
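
As a quick check (a minimal sketch, not from the original report, assuming a fresh session with the repro's Module), the two runs can also be compared structurally instead of eyeballing the printed IR; while the bug is present this assertion fails because only the first run removes the hint_on_device call:

first = relax.transform.RealizeVDevice()(Module)
second = relax.transform.RealizeVDevice()(Module)
tvm.ir.assert_structural_equal(first, second)  # fails while the bug is present
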
Lunderberg commented 1 month ago

Aha! The problem is that HintOnDeviceRemover (the first step of RealizeVDevice) is mutating the relax expression in-place, which is not legal. As a result, expressions that are in the input Module are being mutated to have different StructInfo. The second time that RealizeVDevice is applied, its input has been mutated to already include vdevice annotations.

## Running these commands

Module["foo"].show(name="Before")
relax.transform.RealizeVDevice()(Module)
Module["foo"].show(name="OrigAfter")

## Produces the following output

@R.function
def Before(
    x: R.Tensor((2, 3), dtype="float32"),
    y: R.Tensor((2, 3), dtype="float32"),
    z: R.Tensor((2, 3), dtype="float32"),
) -> R.Tensor((2, 3), dtype="float32"):
    R.func_attr({"global_symbol": "foo"})
    with R.dataflow():
        lv0: R.Tensor((2, 3), dtype="float32") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
        R.output(lv0)
    return lv0

@R.function
def OrigAfter(
    x: R.Tensor((2, 3), dtype="float32"),
    y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:-1"),
    z: R.Tensor((2, 3), dtype="float32"),
) -> R.Tensor((2, 3), dtype="float32"):
    R.func_attr({"global_symbol": "foo"})
    with R.dataflow():
        lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:-1") = R.hint_on_device(
            y, R.device(dev_type=1, dev_id=0)
        )
        R.output(lv0)
    return lv0

The input module should never be modified when running any IRModule transform, so this definitely narrows the bug down to the RealizeVDevice implementation.
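
For reference, a minimal sketch (assuming the repro's Module from above) of how a caller can detect this kind of in-place mutation from the outside: take a structural hash of the input module before running the pass and check that it is unchanged afterwards.

hash_before = tvm.ir.structural_hash(Module)
relax.transform.RealizeVDevice()(Module)
# A well-behaved IRModule transform must leave its input untouched,
# so the structural hash of the input should not change.
assert hash_before == tvm.ir.structural_hash(Module)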

Lunderberg commented 1 month ago

@Cookiee235 Can you verify the fix implemented in #17213? It removes the in-place mutation from RealizeVDevice, and should resolve this issue.

Edit: Whoops, meant @MellowArtisan. With multiple issues/PRs in flight, I got them mixed up.

Cookiee235 commented 1 month ago

@Cookiee235 Can you verify the fix implemented in #17213? It removes the in-place mutation from RealizeVDevice, and should resolve this issue.

@Lunderberg Yes! This PR fixed the bug. Thanks for your efforts!