tlc-pack / relax

Apache License 2.0

[ARCH] NestedMsg util functions update #390

Closed: Ubospica closed this 1 year ago

Ubospica commented 1 year ago

Added two utility functions to nested_msg.h:

SiriusNEO commented 1 year ago

These two utilities are useful for tuple-aware AD (which is still experimental in mlc): https://github.com/mlc-ai/relax/pull/103.

An example application: in tuple-aware AD, we want to maintain the adjoint of each Var, so we keep a map of type Map<Var, Expr>. Because the AD is tuple-aware, however, a Var may be bound to a Tuple expr, and relax.add does not work on Tuples, so updating an adjoint requires a tuple-aware Add. Moreover, initializing an adjoint requires a function that builds a "zeros tuple" from a TupleStructInfo. All of this can be expressed elegantly with NestedMsg.
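To make the tuple-aware Add and the "zeros tuple" construction concrete, here is a minimal, self-contained C++ sketch. It deliberately does not use TVM: `Msg` is a toy stand-in for `NestedMsg<Expr>` whose leaves are `double`s instead of Relax expressions, `SInfo` is a toy stand-in for `TupleStructInfo`, and `ZerosFromSInfo`, `Combine` and `TupleAwareAdd` are hypothetical names chosen for illustration, not the helpers added in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical toy stand-in for NestedMsg<Expr>: a leaf or a tuple of messages.
// Leaves are doubles here; in Relax they would be Exprs (e.g. tensor values).
struct Msg {
  std::variant<double, std::vector<Msg>> data;

  static Msg Leaf(double v) { return Msg{v}; }
  static Msg Tuple(std::vector<Msg> fields) { return Msg{std::move(fields)}; }
  bool IsLeaf() const { return std::holds_alternative<double>(data); }
  double LeafValue() const { return std::get<double>(data); }
  const std::vector<Msg>& Fields() const { return std::get<std::vector<Msg>>(data); }
};

// Hypothetical toy stand-in for StructInfo: a tensor, or a tuple of struct infos.
struct SInfo {
  std::vector<SInfo> fields;
  bool is_tuple = false;
};

SInfo TensorSInfo() { return SInfo{}; }
SInfo TupleSInfo(std::vector<SInfo> fields) { return SInfo{std::move(fields), true}; }

// Build a "zeros tuple" whose nesting mirrors the struct info.
Msg ZerosFromSInfo(const SInfo& sinfo) {
  if (!sinfo.is_tuple) return Msg::Leaf(0.0);  // a real pass would build a zeros tensor
  std::vector<Msg> fields;
  for (const SInfo& field : sinfo.fields) fields.push_back(ZerosFromSInfo(field));
  return Msg::Tuple(std::move(fields));
}

// Combine two messages of identical structure leaf by leaf with `fleaf`.
Msg Combine(const Msg& lhs, const Msg& rhs,
            const std::function<double(double, double)>& fleaf) {
  if (lhs.IsLeaf() && rhs.IsLeaf()) {
    return Msg::Leaf(fleaf(lhs.LeafValue(), rhs.LeafValue()));
  }
  if (lhs.IsLeaf() || rhs.IsLeaf() || lhs.Fields().size() != rhs.Fields().size()) {
    throw std::invalid_argument("messages have different tuple structure");
  }
  std::vector<Msg> fields;
  for (std::size_t i = 0; i < lhs.Fields().size(); ++i) {
    fields.push_back(Combine(lhs.Fields()[i], rhs.Fields()[i], fleaf));
  }
  return Msg::Tuple(std::move(fields));
}

// Tuple-aware Add: ordinary addition at every leaf (relax.add in a real pass).
Msg TupleAwareAdd(const Msg& lhs, const Msg& rhs) {
  return Combine(lhs, rhs, std::plus<double>());
}

// Collect leaves in depth-first order, e.g. to inspect a result.
void Flatten(const Msg& msg, std::vector<double>* out) {
  if (msg.IsLeaf()) {
    out->push_back(msg.LeafValue());
    return;
  }
  for (const Msg& field : msg.Fields()) Flatten(field, out);
}
```

The design choice the sketch illustrates is that only the leaf operation is tensor-specific: the recursion over tuples lives in one generic routine, so the same `Combine` could serve addition or any other binary op by swapping `fleaf`.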

In short, we make the adjoint type NestedMsg<Expr> and the adjoint map Map<Var, NestedMsg<Expr>>. The implementation then maps onto the utilities as follows (format: application <-- nested util):
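A minimal sketch of such an adjoint map, under toy assumptions: `Msg` stands in for `NestedMsg<Expr>` with `double` leaves, variables are identified by name instead of by `relax::Var`, and `AdjointMap`, `AddMsg` and `AccumulateAdjoint` are hypothetical names, not APIs from TVM or this PR. It shows that one accumulation routine handles tensor-valued and tuple-valued variables alike.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical toy stand-in for NestedMsg<Expr> (double leaves instead of Exprs).
struct Msg {
  std::variant<double, std::vector<Msg>> data;

  static Msg Leaf(double v) { return Msg{v}; }
  static Msg Tuple(std::vector<Msg> fields) { return Msg{std::move(fields)}; }
  bool IsLeaf() const { return std::holds_alternative<double>(data); }
  double LeafValue() const { return std::get<double>(data); }
  const std::vector<Msg>& Fields() const { return std::get<std::vector<Msg>>(data); }
};

// Leaf-wise addition of two messages with the same tuple structure.
Msg AddMsg(const Msg& lhs, const Msg& rhs) {
  if (lhs.IsLeaf() && rhs.IsLeaf()) return Msg::Leaf(lhs.LeafValue() + rhs.LeafValue());
  if (lhs.IsLeaf() || rhs.IsLeaf() || lhs.Fields().size() != rhs.Fields().size()) {
    throw std::invalid_argument("adjoint contributions have different tuple structure");
  }
  std::vector<Msg> fields;
  for (std::size_t i = 0; i < lhs.Fields().size(); ++i) {
    fields.push_back(AddMsg(lhs.Fields()[i], rhs.Fields()[i]));
  }
  return Msg::Tuple(std::move(fields));
}

// Hypothetical stand-in for Map<Var, NestedMsg<Expr>>, keyed by variable name.
using AdjointMap = std::map<std::string, Msg>;

// Record one more adjoint contribution for `var`, tensor or tuple.
void AccumulateAdjoint(AdjointMap* adjoints, const std::string& var, const Msg& contrib) {
  auto it = adjoints->find(var);
  if (it == adjoints->end()) {
    adjoints->emplace(var, contrib);  // first contribution: store it as-is
  } else {
    it->second = AddMsg(it->second, contrib);  // later ones: tuple-aware add
  }
}
```

Because the map's value type is the nested message itself, the caller never needs to branch on whether a variable is bound to a Tuple.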

For more details, please refer to the above PR. cc @tqchen @spectrometerHBH @MasterJH5574