apache / tvm

Open deep learning compiler stack for CPU, GPU, and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] RemoveUnusedOutputs gives unexpected results #17247

Closed Cookiee235 closed 2 weeks ago

Cookiee235 commented 1 month ago

Hi all, the pass RemoveUnusedOutputs seems to give an unexpected optimization result. Because this API (relax.transform.RemoveUnusedOutputs) lacks detailed documentation, I cannot confirm whether the optimized result is wrong.

In addition, there is another bug in the API tvm.ir.assert_structural_equal: for the exact same mod, it judges the two structures as unequal. The failure is triggered by IRs containing the string "nan".

Actual behavior

## Output IR after RemoveUnusedOutputs
@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
            R.output(res)
        return res
----------------------------------------------------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/assert_structure.py", line 66, in <module>
    tvm.ir.assert_structural_equal(mod, mod)
  File "/software/tvm-lunder/python/tvm/ir/base.py", line 256, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  5: _ZN3tvm7runtime13PackedFuncObj
  4: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}>(tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  3: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
  2: tvm::SEqualHandlerDefault::Impl::RunTasks()
  1: tvm::SEqualHandlerDefault::DispatchSEqualReduce(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  0: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  File "/software/tvm-lunder/src/node/structural_equal.cc", line 392
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("main")].body.blocks[0].bindings[0].value.fields[0].value.value:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
                                                                                                                                                  ^^^^^
            R.output(res)
        return res
and rhs at <root>.functions[I.GlobalVar("main")].body.blocks[0].bindings[0].value.fields[0].value.value:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
                                                                                                                                                  ^^^^^
            R.output(res)
        return res

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def ones(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 1

    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(32), T.int64(32)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    @R.function(private=True)
    def func() -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        cls = Module
        A = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        B = R.call_tir(cls.ones, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        C = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((32, 32), dtype="int32"))
        return (A, B, C)

    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        R.func_attr({"num_input": 2})
        cls = Module
        with R.dataflow():
            res: R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")) = cls.func()
            R.output(res)
        return res

mod = Module
mod.show()

mod = relax.transform.RemoveUnusedOutputs()(mod)
mod.show()  # is this IR correct?
tvm.ir.assert_structural_equal(mod, mod)  # not equal! why?

cc @Lunderberg @junrushao

Lunderberg commented 1 month ago

Looks like the test case can be made even simpler, and isn't limited to RemoveUnusedOutputs. The root cause is that StructuralEqual compares TIR floats by checking abs(lhs - rhs) < 1e-9, which always evaluates to false for NaN values, since every comparison involving NaN is false.

import tvm
from tvm.script import tir as T

@T.prim_func(private=True)
def func_1():
    return T.float32("nan")

@T.prim_func(private=True)
def func_2():
    return T.float32("nan")

# Both functions are identical, yet this raises a structural-equality error.
tvm.ir.assert_structural_equal(func_1, func_2)
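
The failure mode is easy to reproduce with plain Python floats, independent of TVM:

# Every comparison involving NaN evaluates to False, so a tolerance-based
# equality check can never accept two NaN values.
nan = float("nan")
print(abs(nan - nan) < 1e-9)  # False
print(nan == nan)             # False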

I've implemented #17249, which should fix this issue by giving StructuralEqual and StructuralHash special handling for NaN values.
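
As a rough sketch of what that special handling means (a hypothetical helper for illustration, not the actual TVM implementation):

import math

def floats_structurally_equal(lhs: float, rhs: float, atol: float = 1e-9) -> bool:
    # Treat two NaNs as equal; otherwise fall back to the tolerance comparison.
    if math.isnan(lhs) and math.isnan(rhs):
        return True
    return abs(lhs - rhs) < atol

assert floats_structurally_equal(float("nan"), float("nan"))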

Cookiee235 commented 1 month ago

@Lunderberg Thanks for fixing the incorrect implementation of assert_structural_equal. I have another question. For the IR given in my script, an odd IR is produced after applying the RemoveUnusedOutputs optimization. It seems the function func should not be removed. Can you help me check whether this is a bug?


Output IR after RemoveUnusedOutputs:

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
            R.output(res)
        return res
Lunderberg commented 1 month ago

Ooh, I had missed that part, and thought there was a nan inside the original model. Thank you for calling my attention to it.

Lunderberg commented 1 month ago

The introduction of NaN is a bug in RemoveUnusedOutputs. When determining which outputs of a callee are used, the pass only collected usages of the form TupleGetItem(out_tuple, index). If the tuple is used in a context that doesn't access a specific element, such as being returned from a function, that usage was skipped.

The NaN values are intended as placeholders, dummy values used to keep tuple indices aligned. If a callee produces (A, B, C) but B is never used, the callee would be updated to produce (A, C), and call sites would be updated to replace res = callee() with new_output = callee(); res = (new_output[0], NaN, new_output[1]). The intermediate tuple would then be deconstructed by CanonicalizeBindings, and since nothing ever accesses res[1], the NaN value at that location would drop out altogether.

So, if a function called a subroutine that produces a tuple and then immediately returned that tuple, the usage would fail to be collected, and the tuple elements would be erroneously replaced with NaN. This should be resolved by https://github.com/apache/tvm/pull/17253.
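
For reference, here is a minimal sketch of the intended behavior once usage collection is fixed. It uses generic relax operators instead of the original call_tir functions, and the module and variable names are purely illustrative. Because main only reads res[0], the pass should be able to shrink func's output tuple; in the issue's original module, where the tuple is returned whole, the fixed pass should instead leave it untouched.

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class TupleModule:
    @R.function(private=True)
    def func() -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32")):
        A = R.zeros((16, 16), dtype="int32")
        B = R.ones((16, 16), dtype="int32")
        return (A, B)

    @R.function
    def main() -> R.Tensor((16, 16), dtype="int32"):
        cls = TupleModule
        with R.dataflow():
            res = cls.func()
            out = res[0]  # only the first element of the tuple is ever read
            R.output(out)
        return out

mod = relax.transform.RemoveUnusedOutputs()(TupleModule)
mod.show()  # expected: func now returns only the tensor that main actually uses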

Lunderberg commented 1 month ago

Re-opening, as the auto-close from #17249 wasn't correct. This issue still requires #17253 to land in order to be resolved.