tlc-pack / relax

[Bug][VM] Cannot pass a closure to a function call #220

Open slyubomirsky opened 1 year ago

slyubomirsky commented 1 year ago

The following program results in a VM code generation error:

import tvm
import tvm.script
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class PrintClosure:
    @R.function
    def main():
        @R.function
        def closure():
            return relax.const(1)
        y = relax.print(closure)
        return y

mod = PrintClosure
mod = relax.transform.LambdaLift()(mod)
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)  # error happens on this line
vm = relax.VirtualMachine(ex, tvm.cpu())

ret = vm["main"]()

The error is as follows:

  5: tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  4: tvm::relax::relax_vm::CodeGen(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)
  3: tvm::relax::relax_vm::VMCodeGen::CodeGen(tvm::IRModule)
  2: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::FunctionNode const*)
  1: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::SeqExprNode const*)
  0: tvm::relax::ExprFunctor<tvm::runtime::relax_vm::Instruction::Arg (tvm::RelayExpr const&)>::VisitExprDefault_(tvm::runtime::Object const*)
  File "[...]/relax/include/tvm/relax/expr_functor.h", line 114
TVMError: Do not have a default for GlobalVar

It looks like VM code generation is missing a case for GlobalVar.

psrivas2 commented 1 year ago

Out of curiosity: what is the expected output here? Is there a difference between the following?

  1. y = relax.print(closure)
  2. y = relax.print(closure())

slyubomirsky commented 1 year ago

Calling the closure would return the int. I was curious to see how closures are represented internally, so I don't have a specific expected output in mind, but we should at least be able to generate code for it.
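
As a rough analogy in plain Python (not Relax; the names below are only illustrative), the two cases differ in whether the function value itself or the result of calling it gets passed along:

```python
def closure():
    # Stands in for the inner Relax function that returns the constant 1.
    return 1


passed_function = closure   # case 1: the function value itself
passed_result = closure()   # case 2: the value produced by calling it

print(callable(passed_function))  # the function value can still be called later
print(passed_result)              # the call result is just the int
```

Only the first case requires the VM to have some runtime representation of a function value, which is what this issue exercises.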

YuchenJin commented 1 year ago

Great exploration! For this specific case, the inner closure function inside main is not a closure, since it does not have free variables.

IRModule after lambda lifting (the lambda lifting pass lifts all local functions, regardless of whether they are closures):

@tvm.script.ir_module
class Module:
    @R.function
    def main() -> Tuple():
        # block 0
        closure = lifted_func_0
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0() -> Tensor(None, "int32", ndim = 0):
        return 1

The closure variable is bound to GlobalVar(lifted_func_0). We need to fix VMCodegen to handle binding a GlobalVar to a variable (currently codegen only supports a GlobalVar appearing inline in a CallNode, not one bound to a variable).
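
The failure mode can be sketched with a toy visitor in plain Python. This is not TVM's CodeGenVM: the classes `ToyCodeGen` and `FixedCodeGen`, the node classes, and the `("func_ref", name)` result are all hypothetical and only mirror the shape of the problem, namely that a GlobalVar is handled inside a call but has no handler of its own.

```python
from dataclasses import dataclass


@dataclass
class GlobalVar:
    name: str


@dataclass
class Call:
    op: object
    args: tuple


class ToyCodeGen:
    """Toy stand-in for an expression functor (hypothetical, not TVM code)."""

    def visit(self, expr):
        # Dispatch on the node's class name, falling back to visit_default.
        handler = getattr(self, "visit_" + type(expr).__name__, self.visit_default)
        return handler(expr)

    def visit_default(self, expr):
        raise NotImplementedError(f"Do not have a default for {type(expr).__name__}")

    def visit_Call(self, call):
        # A GlobalVar in callee position is consumed inline and never visited.
        if isinstance(call.op, GlobalVar):
            callee = call.op.name
        else:
            callee = self.visit(call.op)
        return ("call", callee, tuple(self.visit(arg) for arg in call.args))


class FixedCodeGen(ToyCodeGen):
    def visit_GlobalVar(self, gvar):
        # Hypothetical fix: emit a reference to the function by name.
        return ("func_ref", gvar.name)


gvar = GlobalVar("lifted_func_0")
try:
    ToyCodeGen().visit(gvar)  # models `closure = lifted_func_0`
    error = None
except NotImplementedError as exc:
    error = str(exc)

fixed = FixedCodeGen().visit(gvar)
```

In this sketch the call case keeps working in both versions; only the standalone binding needs the extra handler.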

A closure case:

@tvm.script.ir_module
class PrintClosure2:
    @R.function
    def main(x: Tensor((2, 3), "float32")):
        @R.function
        def closure():
            return x
        y = relax.print(closure)
        return y

After lambda lifting:

@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor((2, 3), "float32")) -> Tuple():
        # block 0
        closure: Object = relax.make_closure(lifted_func_0, (x,))
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0(x1: Tensor((2, 3), "float32")) -> Tensor(None, "float32", ndim = 2):
        return x1

The closure is of ObjectType, and it's defined here: https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h#L43. This IRModule can be compiled and run.
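
A minimal plain-Python sketch of what the lifted module above does, under stated assumptions: `make_closure` here is a hand-written stand-in for the Relax operator of the same name, `invoke_closure` is a hypothetical helper, and the argument order (explicit arguments first, captured values after) is an assumption of this sketch rather than something confirmed by the thread.

```python
def make_closure(func, captured):
    # Pair a lifted top-level function with the values it captured.
    return (func, tuple(captured))


def invoke_closure(closure, *args):
    func, captured = closure
    # Assumed convention for this sketch: explicit args first, captured after.
    return func(*args, *captured)


def lifted_func_0(x1):
    # The former free variable x is now an ordinary parameter.
    return x1


def main(x):
    # Mirrors `closure = relax.make_closure(lifted_func_0, (x,))`.
    return make_closure(lifted_func_0, (x,))


clo = main([1.0, 2.0])
result = invoke_closure(clo)
```

The point of the transformation is that `lifted_func_0` no longer refers to anything outside its own parameters, so it can live at module level.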

slyubomirsky commented 1 year ago

Great. Also, on a very pedantic point, I would say that a function that does not capture anything should still be represented at run time as a closure, hence my using that term :)

I think our codegen should treat the cases uniformly, too (i.e., compile a reference to a global func into a closure with no captured variables), or at least make that the convention for passing functions to packed funcs. Later compilation passes could get rid of the closure wrapper if it's never used.

For printing closures, we don't need to display anything other than that it's a closure (most functional languages treat them as completely opaque).
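
One way to picture the uniform convention and the opaque printing together, as a plain-Python sketch: the `Closure` class and `as_closure` helper are hypothetical names, not part of Relax or its VM.

```python
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True, repr=False)
class Closure:
    func: Callable
    captured: Tuple = ()

    def __call__(self, *args):
        # Captured values are appended after the explicit arguments.
        return self.func(*args, *self.captured)

    def __repr__(self):
        # Opaque on purpose: reveal neither the code nor the captured values.
        return "<closure>"


def as_closure(value):
    # A bare function reference becomes a closure with nothing captured.
    return value if isinstance(value, Closure) else Closure(value)


def lifted_func_0():
    return 1


wrapped = as_closure(lifted_func_0)
print(wrapped)
```

With this convention a callee receiving a function value never has to distinguish a lifted global from a true closure; a later pass could strip the wrapper where `captured` is empty and the value is only ever called directly.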