Closed: Ubospica closed this 1 year ago.
These two utils are useful in tuple-aware AD (which is still experimental in mlc): https://github.com/mlc-ai/relax/pull/103.
An example application: in tuple-aware AD, we want to maintain the adjoint of each Var, so we keep a map of type `Map<Var, Expr>`. Because the AD is tuple-aware, a Var may be bound to a Tuple expression, and when we update an adjoint during AD, `relax.add` doesn't work on Tuples. Therefore we need a tuple-aware Add. Moreover, to initialize the adjoint, we need a function that builds a "Zeros Tuple" according to a TupleStructInfo. All of this can be done elegantly with NestedMsg.
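To make the idea concrete, here is a minimal Python sketch. It is not the PR's C++ code. It models a NestedMsg with plain Python values (None for a null message, a Python tuple for a nested message, anything else for a leaf) and uses strings in place of relax expressions. `TensorSInfo`, `TupleSInfo`, `zeros_from_sinfo`, `combine`, `tuple_aware_add` and `accumulate` are hypothetical names for this illustration only.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple, Union


# Hypothetical stand-ins for relax StructInfo; not the real TVM classes.
@dataclass(frozen=True)
class TensorSInfo:
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class TupleSInfo:
    fields: Tuple[Any, ...]  # each field is a TensorSInfo or TupleSInfo


SInfo = Union[TensorSInfo, TupleSInfo]

# Toy NestedMsg encoding: None is "null", a Python tuple is "nested",
# anything else is a leaf. Leaves here are expression strings.
NestedMsg = Any


def zeros_from_sinfo(sinfo: SInfo) -> NestedMsg:
    """Build a "Zeros Tuple" whose nesting mirrors the struct info."""
    if isinstance(sinfo, TupleSInfo):
        return tuple(zeros_from_sinfo(f) for f in sinfo.fields)
    return f"zeros({sinfo.shape})"


def combine(lhs: NestedMsg, rhs: NestedMsg,
            fleaf: Callable[[Any, Any], Any]) -> NestedMsg:
    """Merge two nested messages field by field; null acts as identity."""
    if lhs is None:
        return rhs
    if rhs is None:
        return lhs
    lhs_nested, rhs_nested = isinstance(lhs, tuple), isinstance(rhs, tuple)
    if lhs_nested and rhs_nested:
        if len(lhs) != len(rhs):
            raise ValueError("tuple arity mismatch")
        return tuple(combine(a, b, fleaf) for a, b in zip(lhs, rhs))
    if lhs_nested or rhs_nested:
        raise ValueError("cannot combine a leaf with a tuple")
    return fleaf(lhs, rhs)


def tuple_aware_add(lhs: NestedMsg, rhs: NestedMsg) -> NestedMsg:
    # Only leaves get an add(...) call; tuples are handled structurally.
    return combine(lhs, rhs, lambda a, b: f"add({a}, {b})")


# Adjoint map keyed by variable name (stands in for Map<Var, NestedMsg<Expr>>).
adjoints: Dict[str, NestedMsg] = {}


def accumulate(var: str, sinfo: SInfo, grad: NestedMsg) -> None:
    # First contribution: initialize with a Zeros Tuple built from sinfo.
    if var not in adjoints:
        adjoints[var] = zeros_from_sinfo(sinfo)
    adjoints[var] = tuple_aware_add(adjoints[var], grad)


# Usage: x has struct info (Tensor[2, 3], Tensor[4]) and two uses.
t = TupleSInfo((TensorSInfo((2, 3)), TensorSInfo((4,))))
accumulate("x", t, ("g0", "g1"))
accumulate("x", t, (None, "h1"))  # second use only touches field 1
# adjoints["x"] == ('add(zeros((2, 3)), g0)', 'add(add(zeros((4,)), g1), h1)')
```

Treating a null message as the identity of the merge lets a partial gradient (one that only reaches some tuple fields) be added without materializing zeros for the untouched fields.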
In short, we let the adjoint type be `NestedMsg<Expr>` and the adjoint map be `Map<Var, NestedMsg<Expr>>`. Then the implementation can be: (format: application <-- nested util)
For more details, please refer to the above PR. cc @tqchen @spectrometerHBH @MasterJH5574
Added two util functions in `nested_msg.h`:
- `MapToNestedMsgBySInfo`: maps an Expr to a NestedMsg according to its StructInfo. If the Expr is not a TupleNode but its StructInfo is a TupleStructInfo (for example, the Expr is a Var bound to a Tuple), TupleGetItem nodes are added.
- `NestedMsgToExpr`: maps a NestedMsg back to an Expr, creating Tuple nodes.
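The behavior of the two utils can be illustrated with a small Python model. This is only a sketch under stated assumptions, not the C++ signatures from `nested_msg.h`. It uses a hypothetical toy IR (`Var`, `TupleExpr`, `TupleGetItem`, `TensorSInfo`, `TupleSInfo`) and passes the struct info explicitly instead of reading it from the expression. The functions `map_to_nested_msg_by_sinfo` and `nested_msg_to_expr` are likewise hypothetical names. A nested message is a Python tuple, and anything else is a leaf.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


# Hypothetical toy IR; not the real relax classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class TupleExpr:
    fields: Tuple[Any, ...]


@dataclass(frozen=True)
class TupleGetItem:
    tuple_value: Any
    index: int


@dataclass(frozen=True)
class TensorSInfo:
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class TupleSInfo:
    fields: Tuple[Any, ...]


def map_to_nested_msg_by_sinfo(expr, sinfo, fmapleaf: Callable[[Any], Any]):
    """Expr -> nested message, following the tuple structure of sinfo."""
    if isinstance(sinfo, TupleSInfo):
        if isinstance(expr, TupleExpr):
            parts = expr.fields
        else:
            # Not a literal tuple (e.g. a Var bound to one): project each field.
            parts = tuple(TupleGetItem(expr, i) for i in range(len(sinfo.fields)))
        return tuple(map_to_nested_msg_by_sinfo(p, s, fmapleaf)
                     for p, s in zip(parts, sinfo.fields))
    return fmapleaf(expr)


def nested_msg_to_expr(msg, fmapleaf: Callable[[Any], Any]):
    """Nested message -> Expr, rebuilding TupleExpr nodes for nested parts."""
    if isinstance(msg, tuple):
        return TupleExpr(tuple(nested_msg_to_expr(m, fmapleaf) for m in msg))
    return fmapleaf(msg)


# A Var whose struct info is the nested tuple (Tensor[2], (Tensor[3],)).
x = Var("x")
sinfo = TupleSInfo((TensorSInfo((2,)), TupleSInfo((TensorSInfo((3,)),))))
msg = map_to_nested_msg_by_sinfo(x, sinfo, lambda e: e)
# msg == (TupleGetItem(x, 0), (TupleGetItem(TupleGetItem(x, 1), 0),))
expr = nested_msg_to_expr(msg, lambda leaf: leaf)
# expr == TupleExpr((TupleGetItem(x, 0),
#                    TupleExpr((TupleGetItem(TupleGetItem(x, 1), 0),))))
```

The round trip shows the asymmetry: going in, a non-tuple expression with tuple struct info is split by TupleGetItem, and coming out, every nested level becomes a fresh Tuple node.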