apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] AssertionError in the LazyTransformParams #17269

Open Cookiee235 opened 1 month ago

Cookiee235 commented 1 month ago

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/assert_lazy.py", line 52, in <module>
    mod = relax.transform.LazyTransformParams()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/relax/transform/lazy_transform_params.py", line 396, in transform_module
    func = lazy_mutator.transform(func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/relax/transform/lazy_transform_params.py", line 151, in transform
    forward_collector.visit_expr(func)
  File "/software/tvm-lunder/python/tvm/relax/expr_functor.py", line 346, in visit_expr
    return _ffi_api.PyExprVisitorVisitExpr(self, expr)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/meta_schedule/utils.py", line 76, in method
    return getattr(inst, name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/relax/transform/lazy_transform_params.py", line 59, in visit_var_binding_
    assert isinstance(binding.value, relax.Tuple)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(C: T.Buffer((T.int64(16), T.int64(16)), "float32"), B: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(C[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = C[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def multiply(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_multiply: T.Buffer((T.int64(16), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * T.float32(2)

    @R.function
    def transform_params(A: R.Tensor((16, 16), dtype="float32"), B: R.Tensor((16, 16), dtype="float32")) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
        cls = Module
        C = R.call_tir(cls.multiply, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        D = R.call_tir(cls.add, (C, B), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        para0: R.Tensor((16, 16), dtype="float32") = B
        para1: R.Tensor((16, 16), dtype="float32") = B
        res: R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")) = cls.transform_params_2(para0, para1)
        return res

    @R.function
    def transform_params_2(A: R.Tensor((16, 16), dtype="float32"), B: R.Tensor((16, 16), dtype="float32")) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
        cls = Module
        C = R.call_tir(cls.multiply, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        D = R.call_tir(cls.add, (C, B), out_sinfo=R.Tensor((16, 16), dtype="float32"))
        return (D, D)

mod = Module
mod = relax.transform.LazyTransformParams()(mod)  # crash here

cc @Lunderberg @junrushao

Lunderberg commented 1 month ago

I assume there's a typo and that cls.transform_params_7 is supposed to be cls.transform_params_2. With that, I can reproduce your error.

It looks like this is a limitation of LazyTransformParams: it expects the tuple of outputs to be produced within the function itself, rather than being returned from a subroutine. There are a couple of options for how this can be worked around:
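
As a rough illustration of the limitation only (a minimal sketch, not one of the options from the comment above): re-packing the subroutine's result into a tuple built inside the caller gives the pass an in-function relax.Tuple binding, which is what the failing assertion in visit_var_binding_ checks for. The module name WorkaroundModule and the simplified functions below are invented for this example; the TIR helpers from the reproducer are dropped for brevity, and the sketch has not been verified against the failing pass.

from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class WorkaroundModule:
    @R.function
    def transform_params_2(A: R.Tensor((16, 16), dtype="float32"), B: R.Tensor((16, 16), dtype="float32")) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
        # Stand-in for the real subroutine; it still returns a tuple.
        return (A, B)

    @R.function
    def transform_params(A: R.Tensor((16, 16), dtype="float32"), B: R.Tensor((16, 16), dtype="float32")) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")):
        cls = WorkaroundModule
        res: R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")) = cls.transform_params_2(A, B)
        # Re-pack the subroutine's result so the variable that is returned is
        # bound to a tuple constructed inside this function, rather than to
        # the subroutine call itself.
        return (res[0], res[1])

# Apply the pass to the rewritten module (a sketch, not a verified fix).
mod = relax.transform.LazyTransformParams()(WorkaroundModule)

This only shows the tuple-binding shape the assertion expects; whether the lazily-loaded parameters can still be passed through the subroutine after the rewrite is a separate question.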

Cookiee235 commented 4 weeks ago

@Lunderberg Thanks for your investigation. Such information helps me deeply understand the usage of the different transforms. Because TVM's documentation is incomplete, understanding each transform from the source code alone is difficult. Your explanation helps me a lot! Thanks again.