ASEM000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.
https://pytreeclass.rtfd.io/en/latest
Apache License 2.0

Deprecate auto-unwrapping in favor of Freeze/unfreeze pattern #52

Closed: ASEM000 closed this PR 1 year ago

ASEM000 commented 1 year ago

Description

This PR deprecates the auto-unwrapping of frozen instances inside `treeclass`.

1) Frozen instances are objects wrapped by `pytc.freeze`, which applies `FrozenWrapper` to them.
2) Auto-unwrapping is done at the `__getattribute__` level: if an instance variable is wrapped by `pytc.freeze`, it is unwrapped whenever it is fetched via `self.{attr_name}`.
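To make the deprecated behavior concrete, here is a minimal pure-Python sketch contrasting auto-unwrapping at `__getattribute__` with explicit unwrapping. `Frozen`, `unfreeze`, `AutoUnwrapping` and `Explicit` are hypothetical stand-ins, not the pytreeclass classes.

```python
class Frozen:
    """Hypothetical stand-in for a frozen wrapper: it only holds a value."""

    def __init__(self, value):
        self.value = value


def unfreeze(x):
    # Return the wrapped value for frozen objects; pass everything else through.
    return x.value if isinstance(x, Frozen) else x


class AutoUnwrapping:
    """Old behavior (deprecated by this PR): unwrap on every attribute access."""

    def __init__(self, a):
        self.a = a

    def __getattribute__(self, name):
        value = object.__getattribute__(self, name)
        return unfreeze(value)


class Explicit:
    """New behavior: attributes come back as stored; unwrap explicitly."""

    def __init__(self, a):
        self.a = a


old = AutoUnwrapping(Frozen(1))
new = Explicit(Frozen(1))

old_value = old.a           # 1: unwrapped implicitly by __getattribute__
new_value = new.a           # still a Frozen instance
explicit = unfreeze(new.a)  # 1: unwrapped explicitly by the caller
```

With explicit unwrapping, the class itself no longer hides whether a field is frozen, so the same `unfreeze` step works for any pytree, not just `treeclass` instances.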

The motivation for this PR is to unify the training method used for PyTreeClass and for other pytrees.

Compare this with the splitting/partitioning method:

```python
import jax

tree = [1, 2., 3.]
static_tree = [1, None, None]  # non-differentiable tree (1 is not an inexact type)
dynamic_tree = [None, 2., 3.]  # differentiable tree (2. and 3. are inexact types)

@jax.jit
@jax.grad
def f(dynamic_tree, static_tree):
    # merge/combine the tree; select by `is None` since `x or y` fails on traced values
    tree = jax.tree_util.tree_map(lambda x, y: y if x is None else x, dynamic_tree, static_tree, is_leaf=lambda x: x is None)
    ...
```
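The split/partition approach above can be made fully runnable as follows. This is a minimal sketch, assuming hypothetical `partition`/`combine` helpers (similar in spirit to what some other pytree libraries provide, but not their API), a hypothetical `is_inexact` predicate, and a toy scalar loss.

```python
import jax
import jax.numpy as jnp


def is_inexact(x):
    # Float/complex leaves are treated as differentiable.
    return isinstance(x, (float, complex)) or (
        hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.inexact)
    )


def partition(tree):
    # Hypothetical helper: split one tree into (dynamic, static) trees padded with None.
    dynamic = jax.tree_util.tree_map(lambda x: x if is_inexact(x) else None, tree)
    static = jax.tree_util.tree_map(lambda x: None if is_inexact(x) else x, tree)
    return dynamic, static


def combine(dynamic, static):
    # Hypothetical helper: at each position, take whichever entry is not None.
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        is_leaf=lambda x: x is None,
    )


def loss(dynamic, static):
    tree = combine(dynamic, static)
    return tree[0] * (tree[1] ** 2 + tree[2])


dynamic, static = partition([1, 2.0, 3.0])
grads = jax.jit(jax.grad(loss))(dynamic, static)  # gradient w.r.t. `dynamic` only
```

The cost the PR points out is visible here: two trees of mostly `None`s, and a loss function that must take both of them.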


However, using `freeze`/`unfreeze` does not require splitting/partitioning the tree. Instead, `freeze` wraps the leaf with a registered wrapper that yields no leaves (i.e. it shields the wrapped subtree from any changes).
This pattern is preferable to `split`/`partition` because:
1) It avoids creating an extra tree made up mostly of `None` nodes.
2) It avoids changing the function signature to accept the extra tree as an argument. In the previous example, the function had to accept two arguments instead of a single tree to accommodate the two parts of the tree.
```python
import jax
import pytreeclass as pytc

tree = [1, 2., 3.]
# no extra tree: freeze the non-differentiable leaves in place
tree = jax.tree_util.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, tree)

@jax.jit
@jax.grad
def f(tree):
    # unfreeze the frozen leaves before using the tree
    tree = jax.tree_util.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
    ...
```
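Because the frozen wrapper is just a pytree node with no leaves, the same pattern can be sketched in plain JAX. This is a minimal sketch, assuming a hypothetical `Frozen` class registered with `jax.tree_util.register_pytree_node` in place of pytreeclass's `FrozenWrapper`, hypothetical `freeze`/`unfreeze`/`is_frozen`/`is_nondiff` helpers, and a toy scalar loss.

```python
import jax


class Frozen:
    """Hypothetical stand-in for FrozenWrapper."""

    def __init__(self, value):
        self.value = value


# No children; the wrapped value travels as static aux data, so it is never a leaf.
jax.tree_util.register_pytree_node(
    Frozen,
    lambda f: ((), f.value),         # flatten: (children, aux_data)
    lambda value, _: Frozen(value),  # unflatten: (aux_data, children)
)


def freeze(x):
    return Frozen(x)


def unfreeze(x):
    return x.value if isinstance(x, Frozen) else x


def is_frozen(x):
    return isinstance(x, Frozen)


def is_nondiff(x):
    # Hypothetical rule: only Python floats are differentiable here.
    return not isinstance(x, float)


tree = [1, 2.0, 3.0]
tree = jax.tree_util.tree_map(lambda x: freeze(x) if is_nondiff(x) else x, tree)


@jax.jit
@jax.grad
def loss(tree):
    # Single argument: unfreeze inside the function instead of merging two trees.
    tree = jax.tree_util.tree_map(unfreeze, tree, is_leaf=is_frozen)
    return tree[0] * (tree[1] ** 2 + tree[2])


grads = loss(tree)
```

The gradient tree has the same structure as the input, so position 0 comes back as the `Frozen` wrapper (holding the original value) rather than as a gradient.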

Freezing with `.at`

3) `pytc.freeze` works naturally with `.at` in `treeclass`-wrapped classes:


```python
import pytreeclass as pytc

@pytc.treeclass
class Tree:
    a: int = 1
    b: float = 2.
    c: float = 3.

tree = Tree()

# Directly apply freeze using .at with an attribute name/node index
tree = tree.at['a'].apply(pytc.freeze)

# or apply freeze with a boolean mask to the non-differentiable nodes
tree = tree.at[pytc.bcmap(pytc.is_nondiff)(tree)].apply(pytc.freeze)
```
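The two selection styles above (by name and by boolean mask) can be sketched without pytreeclass. This is a minimal sketch, assuming a plain dict in place of the `treeclass` instance, a hypothetical `apply_where` helper that mirrors the described behavior of `.at[...].apply(fn)` (apply `fn` to the selected nodes, leave the rest unchanged), and a hypothetical `mark` function standing in for `pytc.freeze`.

```python
import jax


def apply_where(tree, mask, fn):
    """Hypothetical helper: apply `fn` to leaves whose mask entry is True."""
    return jax.tree_util.tree_map(lambda x, m: fn(x) if m else x, tree, mask)


def is_nondiff(x):
    # Hypothetical rule: anything that is not a Python float is non-differentiable.
    return not isinstance(x, float)


def mark(x):
    # Hypothetical stand-in for a freeze function.
    return ("frozen", x)


tree = {"a": 1, "b": 2.0, "c": 3.0}

# Select by name: a mask that is True only for the "a" entry.
by_name = apply_where(tree, {k: k == "a" for k in tree}, mark)

# Select by predicate: a boolean mask with the same structure as the tree.
by_mask = apply_where(tree, jax.tree_util.tree_map(is_nondiff, tree), mark)
```

Both routes end in the same place here because `a` is the only non-float field; the mask route scales to larger trees without naming every field.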

Relationship with other pytrees

4) `pytc.freeze` can also be used with other pytrees, because it simply applies `FrozenWrapper` to an object, and `FrozenWrapper` is a registered pytree that yields no leaves when flattened.
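Since the wrapper is an ordinary registered pytree node, it can sit inside any container JAX already understands. This is a minimal sketch, assuming a hypothetical `Frozen` class in place of `FrozenWrapper` and a plain dict/tuple standing in for "other pytrees".

```python
import jax


class Frozen:
    """Hypothetical stand-in for FrozenWrapper: a pytree node with no leaves."""

    def __init__(self, value):
        self.value = value


jax.tree_util.register_pytree_node(
    Frozen,
    lambda f: ((), f.value),         # no children; value kept as aux data
    lambda value, _: Frozen(value),
)

# A frozen value nested inside standard containers (dict and tuple).
params = {"w": 2.0, "meta": (Frozen("relu"), 3.0)}

leaves, treedef = jax.tree_util.tree_flatten(params)    # the frozen entry contributes no leaf
scaled = jax.tree_util.tree_map(lambda x: x * 10, params)  # frozen entry is left untouched
restored = jax.tree_util.tree_unflatten(treedef, leaves)   # round trip keeps the wrapper
```

Without the wrapper, `tree_map` would have applied `x * 10` to the string `"relu"` as well.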

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 94.11% and project coverage change: -0.04 :warning:

Comparison is base (d513182) 99.60% compared to head (28ba642) 99.56%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #52      +/-   ##
==========================================
- Coverage   99.60%   99.56%   -0.04%
==========================================
  Files          14       14
  Lines        2775     2767       -8
==========================================
- Hits         2764     2755       -9
- Misses         11       12       +1
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| `pytreeclass/_src/tree_trace.py` | `100.00% <ø> (ø)` | |
| `pytreeclass/_src/tree_decorator.py` | `99.53% <80.00%> (-0.47%)` | :arrow_down: |
| `pytreeclass/__init__.py` | `100.00% <100.00%> (ø)` | |
| `pytreeclass/_src/tree_freeze.py` | `100.00% <100.00%> (ø)` | |
| `pytreeclass/_src/tree_pprint.py` | `100.00% <100.00%> (ø)` | |
| `tests/test_tree_freeze.py` | `100.00% <100.00%> (ø)` | |
| `tests/test_tree_pprint.py` | `98.43% <100.00%> (ø)` | |
| `tests/test_treeclass.py` | `99.15% <100.00%> (ø)` | |
| `tests/test_under_jit.py` | `100.00% <100.00%> (ø)` | |
