slyubomirsky opened 2 years ago
Out of curiosity: what is the expected output here? Is there a difference between the following:
y = relax.print(closure)
y = relax.print(closure())
Calling the closure would return the int. I was curious to see how closures were represented internally, so I'm not sure what the expected output necessarily was, but we should be able to generate code for it.
Great exploration! For this specific case, the inner closure function inside main is not a closure, since it does not have free variables.
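To make the "no free variables" criterion concrete, here is a minimal sketch over a toy expression language. The Var, Const, Func, and Call classes are hypothetical and are not Relax's IR classes. In this model, a local function needs a captured environment only if its set of free variables is non-empty.

```python
from dataclasses import dataclass
from typing import Set, Tuple, Union


# Hypothetical toy IR, used only to illustrate free-variable analysis.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Func:
    params: Tuple[str, ...]
    body: "Expr"


@dataclass(frozen=True)
class Call:
    fn: "Expr"
    args: Tuple["Expr", ...]


Expr = Union[Var, Const, Func, Call]


def free_vars(e: Expr) -> Set[str]:
    """Return the variable names used in e that are not bound inside e."""
    if isinstance(e, Var):
        return {e.name}
    if isinstance(e, Const):
        return set()
    if isinstance(e, Func):
        # Parameters are bound inside the body, so they are not free.
        return free_vars(e.body) - set(e.params)
    if isinstance(e, Call):
        result = free_vars(e.fn)
        for arg in e.args:
            result |= free_vars(arg)
        return result
    raise TypeError(f"unknown expression: {e!r}")


def captures_anything(f: Func) -> bool:
    """A function needs a captured environment only if it has free variables."""
    return bool(free_vars(f))


# Like the first example: the inner function returns a constant.
no_capture = Func(params=(), body=Const(1))
# Like the second example: the inner function returns x from main's scope.
captures_x = Func(params=(), body=Var("x"))
```

For example, `captures_anything(no_capture)` is `False`, while `free_vars(captures_x)` is `{"x"}`.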
IRModule after lambda lifting (the lambda lifting pass lifts all local functions, whether or not they are closures):
@tvm.script.ir_module
class Module:
    @R.function
    def main() -> Tuple():
        # block 0
        closure = lifted_func_0
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0() -> Tensor(None, "int32", ndim = 0):
        return 1
The closure variable is bound to a GlobalVar (lifted_func_0). We need to fix VMCodegen to handle binding a GlobalVar to a variable (currently, codegen only supports a GlobalVar that appears inline as the callee of a CallNode, not one bound to a variable).
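A minimal sketch of the missing case, assuming a toy IR and pseudo-instructions rendered as strings. The node classes and the instruction names (call, invoke_closure, make_closure, move) are made up for illustration and are not the Relax VM's instruction set. A call whose callee is a GlobalVar is emitted inline, as codegen already does. A bare GlobalVar bound to a variable is wrapped in a closure with no captured variables.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


# Hypothetical toy IR nodes, loosely mirroring the bindings in the lifted module.
@dataclass(frozen=True)
class GlobalVar:
    name: str


@dataclass(frozen=True)
class LocalVar:
    name: str


@dataclass(frozen=True)
class Call:
    callee: Union[GlobalVar, LocalVar]
    args: Tuple[LocalVar, ...]


Value = Union[GlobalVar, LocalVar, Call]


def codegen_binding(dst: str, value: Value) -> List[str]:
    """Emit hypothetical pseudo-instructions (as strings) for `dst = value`."""
    if isinstance(value, Call):
        args = ", ".join(a.name for a in value.args)
        if isinstance(value.callee, GlobalVar):
            # Already-supported case: the GlobalVar appears inline as the callee.
            return [f"{dst} = call {value.callee.name}({args})"]
        # Calling a function value held in a variable, e.g. a closure.
        return [f"{dst} = invoke_closure {value.callee.name}({args})"]
    if isinstance(value, GlobalVar):
        # Missing case: a GlobalVar bound to a variable. Wrap it in a
        # closure with no captured variables so it can be passed around.
        return [f"{dst} = make_closure {value.name}()"]
    if isinstance(value, LocalVar):
        return [f"{dst} = move {value.name}"]
    raise TypeError(f"unsupported binding value: {value!r}")
```

With this rule, the binding `closure = lifted_func_0` lowers to `closure = make_closure lifted_func_0()`.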
A closure case:
@tvm.script.ir_module
class PrintClosure2:
    @R.function
    def main(x: Tensor((2, 3), "float32")):
        @R.function
        def closure():
            return x

        y = relax.print(closure)
        return y
After lambda lifting:
@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor((2, 3), "float32")) -> Tuple():
        # block 0
        closure: Object = relax.make_closure(lifted_func_0, (x,))
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0(x1: Tensor((2, 3), "float32")) -> Tensor(None, "float32", ndim = 2):
        return x1
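The two lifted modules above follow one rule: each local function becomes a global function that takes its free variables as parameters. The local binding is then replaced either by a bare reference to that global (no free variables) or by a make_closure over the captured values. Below is a minimal sketch of that rule under a toy representation in which a function body is summarized by the variables it references. LocalFunc, LiftedFunc, and LambdaLifter are hypothetical names. Placing captured variables before the function's own parameters is an assumption of this sketch. Unlike the real pass, which renames the lifted parameter (x becomes x1), the sketch keeps the original names.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass(frozen=True)
class LocalFunc:
    params: Tuple[str, ...]  # the function's own parameters
    uses: Tuple[str, ...]    # variables referenced in the body (toy stand-in for a real body)


@dataclass(frozen=True)
class LiftedFunc:
    name: str
    params: Tuple[str, ...]  # captured variables first, then the original parameters


@dataclass
class LambdaLifter:
    functions: Dict[str, LiftedFunc] = field(default_factory=dict)
    counter: int = 0

    def lift(self, fn: LocalFunc) -> str:
        """Lift fn to a global function; return the expression that replaces it locally."""
        # Free variables: referenced in the body but not bound by the parameters.
        # dict.fromkeys deduplicates while keeping first-occurrence order.
        captured = tuple(dict.fromkeys(v for v in fn.uses if v not in fn.params))
        gname = f"lifted_func_{self.counter}"
        self.counter += 1
        self.functions[gname] = LiftedFunc(gname, captured + fn.params)
        if not captured:
            # No free variables: refer to the lifted global directly.
            return gname
        # Free variables: package the global function with the captured values.
        return f"make_closure({gname}, ({', '.join(captured)},))"
```

Lifting the first example's function (no uses) yields the bare name `lifted_func_0`. Lifting a function that uses only `x` yields `make_closure(lifted_func_1, (x,))` when it is the lifter's second function.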
The closure is of ObjectType, and it's defined here: https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h#L43. This IRModule can be compiled and run.
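The linked header defines the actual runtime object. What follows is only a hypothetical Python model of what a closure value needs to carry: the name of the lifted function plus the values captured by make_closure. Invoking it passes the captured values ahead of the call's own arguments. That ordering is an assumed convention and is not taken from the VM source. The repr is deliberately opaque, which is how most functional languages print closures.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class Closure:
    func_name: str             # name of the lifted global function
    captured: Tuple[Any, ...]  # values captured when the closure was made

    def __repr__(self) -> str:
        # Print closures opaquely: expose neither the function nor the captured values.
        return "<closure>"


def invoke_closure(functions: Dict[str, Callable[..., Any]], clo: Closure, *args: Any) -> Any:
    """Call a closure: captured values come first, then the call's own arguments."""
    return functions[clo.func_name](*clo.captured, *args)


# Stand-in for the lifted function in the second module: `lifted_func_0(x1)` returns x1.
functions: Dict[str, Callable[..., Any]] = {"lifted_func_0": lambda x1: x1}
```

Under this model, printing `Closure("lifted_func_0", (x,))` shows only `<closure>`, while invoking it returns `x`.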
Great. Also, on a very pedantic point, I would say that a function that does not capture anything should still be represented at run time as a closure, hence my using that term :)
I think our codegen should treat the cases uniformly, too (i.e., compile a reference to a global func into a closure with no captured variables), or at least make that the convention for passing functions to packed funcs. Later compilation passes could get rid of the closure wrapper if it's never used.
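One way a later pass could remove the closure wrapper is sketched below. It works over a hypothetical straight-line, SSA-style list of bindings, and both the binding encoding and the op names are invented for illustration. A capture-free closure whose variable is used only as a callee is dropped, and its calls are redirected to the wrapped global function. If the closure escapes (passed to print, or listed in outputs as a returned value), the wrapper is kept.

```python
from typing import Dict, List, Set, Tuple

# A binding is (target, op, operands). Hypothetical ops:
#   ("c", "make_closure", ("f",))       c = closure over global f with no captures
#   ("c", "make_closure", ("f", "x"))   c = closure over global f capturing x
#   ("y", "call", ("c", "a"))           y = c(a); the first operand is the callee
#   ("z", "print", ("c",))              z = print(c)
Binding = Tuple[str, str, Tuple[str, ...]]


def strip_unused_closures(bindings: List[Binding], outputs: Tuple[str, ...] = ()) -> List[Binding]:
    # Map each capture-free closure variable to the global function it wraps.
    wrapped: Dict[str, str] = {
        dst: operands[0]
        for dst, op, operands in bindings
        if op == "make_closure" and len(operands) == 1
    }
    # A closure escapes if it is returned or used anywhere other than as a callee.
    escaping: Set[str] = {v for v in outputs if v in wrapped}
    for _, op, operands in bindings:
        if op == "make_closure":
            continue  # operands name a global function and captures, not closure uses
        non_callee = operands[1:] if op == "call" else operands
        escaping.update(v for v in non_callee if v in wrapped)
    removable = {v: g for v, g in wrapped.items() if v not in escaping}

    result: List[Binding] = []
    for dst, op, operands in bindings:
        if dst in removable:
            continue  # drop the now-unused wrapper
        if op == "call" and operands[0] in removable:
            # Call the global function directly instead of going through the closure.
            operands = (removable[operands[0]],) + operands[1:]
        result.append((dst, op, operands))
    return result
```

A wrapper that is only called, such as `c = make_closure f` followed by `y = c(a)`, collapses to `y = f(a)`. A wrapper that is printed stays, matching the case in this thread.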
For printing closures, we don't need to display anything other than that it's a closure (most functional languages treat them as completely opaque).
The following program results in a VM code generation error:
The error is as follows:
It looks like there's a missing case in VM code generation.