ASEM000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.
https://pytreeclass.rtfd.io/en/latest
Apache License 2.0

v0.4 #73

Closed: ASEM000 closed this 1 year ago.

ASEM000 commented 1 year ago

1) Add `tree_mask` and `tree_unmask` to freeze/unfreeze tree leaves based on a callable or boolean pytree mask. By default, leaves of non-inexact types are masked with a frozen wrapper.

Example: pass non-JAX types through a JAX transformation without error.

```python
# pass non-differentiable values to `jax.grad`
import jax
import pytreeclass as pytc

@jax.grad
def square(tree):
    tree = pytc.tree_unmask(tree)
    return tree[0] ** 2

tree = (1., 2)  # contains a non-differentiable node
square(pytc.tree_mask(tree))
# (Array(2., dtype=float32, weak_type=True), #2)
```
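To illustrate the idea behind masking, here is a minimal sketch that does not use pytreeclass. It assumes masking works by wrapping a leaf in a custom pytree node that exposes no children, so JAX transformations never see the wrapped value. The names `Frozen`, `is_inexact`, `mask`, and `unmask` are hypothetical stand-ins, not the library's implementation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map


@register_pytree_node_class
class Frozen:
    """Hypothetical wrapper: stores its value as static aux data and exposes no leaves."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children, so JAX sees zero leaves inside a Frozen node.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

    def __repr__(self):
        return f"#{self.value!r}"


def is_inexact(x):
    # Python floats/complex and floating/complex arrays count as inexact.
    if isinstance(x, (float, complex)):
        return True
    return hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.inexact)


def mask(tree):
    # Wrap every non-inexact leaf so transformations skip it.
    return tree_map(lambda x: x if is_inexact(x) else Frozen(x), tree)


def unmask(tree):
    # Treat Frozen nodes as leaves so they can be unwrapped.
    return tree_map(
        lambda x: x.value if isinstance(x, Frozen) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Frozen),
    )


@jax.grad
def square(tree):
    tree = unmask(tree)
    return tree[0] ** 2


grads = square(mask((1.0, 2)))
# grads[0] is the gradient 2.0; grads[1] is still the Frozen(2) wrapper.
```

Because `Frozen` contributes no leaves, `jax.grad` only differentiates with respect to `1.0`, and the integer rides along unchanged in the tree structure.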

2) A user-provided `re.Pattern` is now used directly to match keys by regex, replacing `RegexKey`.

Example:

```python
import re

import pytreeclass as pytc

tree = {"l1": 1, "l2": 2, "b": 3}
tree = pytc.AtIndexer(tree)
tree.at[re.compile("l.*")].get()
# {'b': None, 'l1': 1, 'l2': 2}
```
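The selection semantics in this example can be mimicked in plain Python for a flat dict. This sketch assumes a compiled pattern must match the whole key (`re.fullmatch`) and that unselected leaves are replaced by `None`; `select_by_pattern` is a hypothetical helper, not part of pytreeclass.

```python
import re
from typing import Any, Dict, Optional


def select_by_pattern(tree: Dict[str, Any], pattern: re.Pattern) -> Dict[str, Optional[Any]]:
    """Keep values whose key fully matches `pattern`; replace the rest with None."""
    return {
        key: (value if pattern.fullmatch(key) else None)
        for key, value in tree.items()
    }


tree = {"l1": 1, "l2": 2, "b": 3}
print(select_by_pattern(tree, re.compile("l.*")))
# {'l1': 1, 'l2': 2, 'b': None}
```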

3) Support extending key matching by subclassing the abstract base class `BaseKey`; see its docstring for an example.
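The extension mechanism can be sketched with Python's `abc` module. Everything here is hypothetical: `KeyMatcher` stands in for `BaseKey`, and the `match(key) -> bool` method is an assumed interface; the `BaseKey` docstring documents the real signature.

```python
import abc
from typing import Any, Dict


class KeyMatcher(abc.ABC):
    """Hypothetical stand-in for an abstract key-matching base class."""

    @abc.abstractmethod
    def match(self, key: str) -> bool:
        """Return True if the leaf stored under `key` should be selected."""


class PrefixKey(KeyMatcher):
    """Example extension: select keys that start with a given prefix."""

    def __init__(self, prefix: str):
        self.prefix = prefix

    def match(self, key: str) -> bool:
        return key.startswith(self.prefix)


def select(tree: Dict[str, Any], matcher: KeyMatcher) -> Dict[str, Any]:
    # Unselected leaves become None, mirroring the `.get()` examples in this PR.
    return {k: (v if matcher.match(k) else None) for k, v in tree.items()}


tree = {"l1": 1, "l2": 2, "b": 3}
print(select(tree, PrefixKey("l")))
# {'l1': 1, 'l2': 2, 'b': None}
```

Because `match` is abstract, the base class cannot be instantiated directly; any new matching rule only has to implement that one method.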

4) Support multi-indexing with any accepted index form, e.g. a boolean pytree, a key, an int, or a `BaseKey` instance.

Example:

```python
import pytreeclass as pytc

tree = {"l1": 1, "l2": 2, "b": 3}
tree = pytc.AtIndexer(tree)
tree.at["l1", "l2"].get()
# {'b': None, 'l1': 1, 'l2': 2}
```
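One plausible reading of multi-indexing is that a leaf is selected when any of the given indices selects it. The sketch below assumes that union semantics and handles three of the index forms listed above for a flat dict: a plain key, a compiled regex, and a boolean mask with the same keys as the tree. `select_many` and `_selects` are hypothetical helpers.

```python
import re
from typing import Any, Dict, Union

Index = Union[str, re.Pattern, Dict[str, bool]]


def _selects(index: Index, key: str) -> bool:
    if isinstance(index, str):
        return key == index  # exact key
    if isinstance(index, re.Pattern):
        return index.fullmatch(key) is not None  # regex key
    if isinstance(index, dict):
        return bool(index.get(key, False))  # boolean mask
    raise TypeError(f"unsupported index type: {type(index).__name__}")


def select_many(tree: Dict[str, Any], *indices: Index) -> Dict[str, Any]:
    # A leaf is kept if at least one index selects it; others become None.
    return {
        key: (value if any(_selects(i, key) for i in indices) else None)
        for key, value in tree.items()
    }


tree = {"l1": 1, "l2": 2, "b": 3}
print(select_many(tree, "l1", "l2"))
# {'l1': 1, 'l2': 2, 'b': None}
print(select_many(tree, re.compile("b"), {"l1": True, "l2": False, "b": False}))
# {'l1': 1, 'l2': None, 'b': 3}
```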

5) Add `scan` to `AtIndexer` to carry a state while applying a function to the selected leaves.

Example:

```python
import pytreeclass as pytc

def scan_func(leaf, state):
    # square the leaf and increase the state by 1 for each function call
    return leaf**2, state + 1

tree = {"l1": 1, "l2": 2, "b": 3}
tree = pytc.AtIndexer(tree)
tree, state = tree.at["l1", "l2"].scan(scan_func, 0)
state
# 2
tree
# {'b': 3, 'l1': 1, 'l2': 4}
```
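The carry semantics of `scan` can be sketched in plain Python: the selected leaves are visited one at a time, and each call receives the state returned by the previous call. This assumes a sorted-key traversal order (JAX flattens dicts in sorted key order) and handles only a flat dict with exact keys; `scan_at` is a hypothetical helper.

```python
from typing import Any, Callable, Dict, Iterable, Tuple


def scan_at(
    tree: Dict[str, Any],
    keys: Iterable[str],
    func: Callable[[Any, Any], Tuple[Any, Any]],
    state: Any,
) -> Tuple[Dict[str, Any], Any]:
    """Apply `func(leaf, state) -> (new_leaf, new_state)` to selected leaves in order."""
    selected = set(keys)
    out = {}
    for key in sorted(tree):  # assumed traversal order: sorted keys
        if key in selected:
            out[key], state = func(tree[key], state)  # thread the state through
        else:
            out[key] = tree[key]  # unselected leaves pass through unchanged
    return out, state


def scan_func(leaf, state):
    # square the leaf and count how many times the function ran
    return leaf**2, state + 1


tree, state = scan_at({"l1": 1, "l2": 2, "b": 3}, ["l1", "l2"], scan_func, 0)
print(state)  # 2
print(tree)   # {'b': 3, 'l1': 1, 'l2': 4}
```

Returning the state explicitly, rather than mutating an outer counter, keeps `scan_func` a pure function, which is what JAX transformations expect.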

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage is 91.06%, and project coverage changed by -0.29%.

Comparison is base (124f5e6) 97.54% compared to head (32c56bf) 97.25%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #73      +/-   ##
==========================================
- Coverage   97.54%   97.25%   -0.29%
==========================================
  Files          13       15       +2
  Lines        2403     2478      +75
==========================================
+ Hits         2344     2410      +66
- Misses         59       68       +9
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| `tests/test_nn.py` | `98.65% <ø> (ø)` | |
| `tests/test_tree_operator.py` | `100.00% <ø> (ø)` | |
| `tests/test_tree_pprint.py` | `96.69% <ø> (ø)` | |
| `tests/test_tree_viz_util.py` | `96.00% <ø> (ø)` | |
| `tests/test_treeclass.py` | `98.77% <ø> (ø)` | |
| `tests/test_under_jit.py` | `100.00% <ø> (ø)` | |
| `pytreeclass/_src/tree_pprint.py` | `91.97% <78.57%> (-0.52%)` | ↓ |
| `pytreeclass/_src/tree_mask.py` | `86.36% <86.36%> (ø)` | |
| `pytreeclass/_src/tree_index.py` | `93.40% <93.40%> (ø)` | |
| `pytreeclass/__init__.py` | `100.00% <100.00%> (ø)` | |
| ... and 5 more | | |
