Closed: ASEM000 closed this PR 1 year ago.
Patch coverage: 94.11% and project coverage change: -0.04 :warning:
Comparison is base (d513182) 99.60% compared to head (28ba642) 99.56%.
Description
This PR deprecates the automatic unwrapping of frozen instances inside `treeclass`.
1) Frozen instances are objects wrapped by `pytc.freeze`, which applies `FrozenWrapper` to them.
2) Auto-unwrapping is done at the `__getattribute__` level. In essence, if an instance variable is wrapped by `pytc.freeze`, then fetching it with `self.{attr_name}` returns the unwrapped value.

The motivation for this PR is to use one training method for both PyTreeClass and other pytrees.

Comparing with the splitting/partitioning method:
Assume `tree = [1, 2., 3.]`. We can split `tree` into a dynamic tree and a static tree using `jax.tree_map` (`jax.tree_util.tree_map` in current JAX), pass both parts across a JAX function boundary, and then merge them back inside the function:

```python
import jax

tree = [1, 2., 3.]
static_tree = [1, None, None]  # -> non-differentiable tree (1 is not an inexact type)
dynamic_tree = [None, 2., 3.]  # -> differentiable tree (2. and 3. are inexact types)

@jax.jit
@jax.grad
def f(dynamic_tree, static_tree):
    # merge/combine the tree; test for None explicitly, since `x or y` fails on traced arrays
    tree = jax.tree_util.tree_map(lambda x, y: y if x is None else x,
                                  dynamic_tree, static_tree, is_leaf=lambda x: x is None)
    ...  # rest of the computation; must return a scalar for jax.grad
```
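The text mentions splitting with `jax.tree_map`, while the snippet writes both halves by hand. Below is a minimal sketch of doing the split programmatically, assuming that "differentiable" means "has an inexact (floating or complex) dtype". `is_inexact`, `partition` and `combine` are hypothetical helper names for illustration, not part of PyTreeClass or JAX.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def is_inexact(x):
    # Hypothetical predicate: floating/complex leaves count as differentiable.
    return jnp.issubdtype(jnp.result_type(x), jnp.inexact)


def partition(tree):
    """Hypothetical helper: split a pytree into (dynamic, static) halves,
    using None as the placeholder for the missing side."""
    dynamic = tree_util.tree_map(lambda x: x if is_inexact(x) else None, tree)
    static = tree_util.tree_map(lambda x: None if is_inexact(x) else x, tree)
    return dynamic, static


def combine(dynamic, static):
    """Hypothetical helper: merge the halves back; None marks the missing side."""
    return tree_util.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        is_leaf=lambda x: x is None,  # treat None placeholders as leaves
    )


@jax.jit
@jax.grad  # differentiates w.r.t. the first argument (the dynamic half) only
def loss(dynamic, static):
    tree = combine(dynamic, static)
    return tree[0] * tree[1] + tree[2] ** 2


tree = [1, 2.0, 3.0]
dynamic, static = partition(tree)  # ([None, 2.0, 3.0], [1, None, None])
grads = loss(dynamic, static)      # [None, 1.0, 6.0] (gradients as JAX arrays)
```

The gradient has the same structure as the dynamic half, so the static position stays `None`. This is the bookkeeping that freezing is meant to replace.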
Freezing with `.at`

3) `pytc.freeze` plays nicely with `.at` in `treeclass`-wrapped classes.

Relationship with other pytrees
4) `pytc.freeze` can be used with other pytrees because it simply applies `FrozenWrapper` to an object, and `FrozenWrapper` is a registered pytree that yields no leaves when flattened.
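To make points 1), 2) and 4) concrete, here is a minimal plain-JAX sketch. `FrozenWrapper`, `freeze`, `unfreeze` and `AutoUnwrap` below are hypothetical stand-ins written for illustration, not PyTreeClass's actual implementation. The sketch assumes the frozen value is a hashable Python object, because it is stored as pytree aux data.

```python
import jax
from jax import tree_util


class FrozenWrapper:
    """Hypothetical stand-in for pytc's FrozenWrapper (not the real class)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"FrozenWrapper({self.value!r})"


# Flatten to zero children: the wrapped value rides along as (hashable) aux
# data, so tree utilities and transformations see no leaves here.
tree_util.register_pytree_node(
    FrozenWrapper,
    lambda frozen: ((), frozen.value),
    lambda value, _children: FrozenWrapper(value),
)


def freeze(value):
    """Hypothetical helper playing the role of pytc.freeze."""
    return FrozenWrapper(value)


def unfreeze(value):
    """Hypothetical helper: explicitly unwrap a frozen value."""
    return value.value if isinstance(value, FrozenWrapper) else value


class AutoUnwrap:
    """Hypothetical class mimicking the deprecated behaviour:
    attribute access transparently unwraps frozen values."""

    def __init__(self, **fields):
        for name, value in fields.items():
            object.__setattr__(self, name, value)

    def __getattribute__(self, name):
        return unfreeze(object.__getattribute__(self, name))


# Works with any pytree, e.g. a plain dict: the frozen entry has no leaves.
params = {"w": 2.0, "b": freeze(3.0)}
leaves = tree_util.tree_leaves(params)  # [2.0], only the unfrozen "w" leaf


def loss(p):
    # Without auto-unwrapping, the frozen value is unwrapped explicitly.
    return p["w"] * unfreeze(p["b"])


grads = jax.grad(loss)(params)  # grads["w"] == 3.0; grads["b"] stays frozen
```

Because the wrapper contributes no leaves, `jax.grad` never sees the frozen value, and the same mechanism works for any container pytree. `AutoUnwrap` only illustrates the `__getattribute__`-level unwrapping that this PR deprecates.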