metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

feat(ops): add tree flatten and tree map functions with extra paths #11

Closed: XuehaiPan closed this pull request 1 year ago.

XuehaiPan commented 1 year ago

Motivation and Context

Add functions tree_flatten_with_path, tree_paths, and tree_map_with_path.

Each path is a tuple of entries (indices or keys) leading from the root to the corresponding leaf, so the depth of a leaf can be derived as depth = len(path).
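To make the path definition concrete, here is a minimal pure-Python sketch of flattening with paths. `flatten_with_path` is a hypothetical helper, not optree's implementation. It assumes only tuples, lists, dicts and None as internal nodes, visits dict keys in sorted order, and treats None as an empty node with no leaves. These choices are modeled on optree's defaults.

```python
from typing import Any, List, Tuple

Path = Tuple[Any, ...]


def flatten_with_path(tree: Any, path: Path = ()) -> Tuple[List[Path], List[Any]]:
    """Hypothetical helper: return (paths, leaves) in depth-first order."""
    if tree is None:
        # None is an empty node: it contributes no leaves (assumed default)
        return [], []
    if isinstance(tree, (tuple, list)):
        items = list(enumerate(tree))  # sequence entries are integer indices
    elif isinstance(tree, dict):
        items = [(k, tree[k]) for k in sorted(tree)]  # dict entries are keys, sorted
    else:
        return [path], [tree]  # anything else is a leaf
    paths: List[Path] = []
    leaves: List[Any] = []
    for entry, child in items:
        child_paths, child_leaves = flatten_with_path(child, path + (entry,))
        paths.extend(child_paths)
        leaves.extend(child_leaves)
    return paths, leaves


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
paths, leaves = flatten_with_path(tree)
depths = [len(p) for p in paths]  # depth = len(path)
# paths  -> [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)]
# leaves -> [1, 2, 3, 4, 5]
# depths -> [1, 2, 3, 3, 1]
```

A bare leaf at the root gets the empty path `()`, so its depth is 0.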

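```python
import optree

tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
rests = ()  # optional extra trees with the same structure as `tree`
paths, leaves, treespec = optree.tree_flatten_with_path(tree)
paths = optree.tree_paths(tree)  # paths only

def fn(p, x, *xs):  # p: path, x: leaf of `tree`, xs: matching leaves of `rests`
    return len(p)  # e.g. replace each leaf with its depth

mapped_tree = optree.tree_map_with_path(fn, tree, *rests)
```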
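The calling convention of `tree_map_with_path` passes the path first, then the leaf from `tree`, then the matching leaves from each tree in `rests`. The sketch below mirrors that convention in pure Python. `map_with_path` is a hypothetical helper, not optree's implementation, which also handles namedtuples and custom registered node types. For simplicity it requires every tree in `rests` to have exactly the same structure as `tree`, and it supports only plain tuples, lists, dicts and None.

```python
from typing import Any, Callable, Tuple


def map_with_path(fn: Callable[..., Any], tree: Any, *rests: Any,
                  path: Tuple[Any, ...] = ()) -> Any:
    """Hypothetical helper: apply fn(path, leaf, *other_leaves) to every leaf."""
    if tree is None:
        return None  # empty node: nothing to map
    if isinstance(tree, (tuple, list)):
        for r in rests:
            # Each extra tree must mirror the structure of `tree` at this node
            if type(r) is not type(tree) or len(r) != len(tree):
                raise ValueError(f"structure mismatch at path {path}")
        children = [
            map_with_path(fn, child, *(r[i] for r in rests), path=path + (i,))
            for i, child in enumerate(tree)
        ]
        return type(tree)(children)  # rebuild a plain tuple or list
    if isinstance(tree, dict):
        for r in rests:
            if not isinstance(r, dict) or set(r) != set(tree):
                raise ValueError(f"structure mismatch at path {path}")
        return {
            k: map_with_path(fn, tree[k], *(r[k] for r in rests), path=path + (k,))
            for k in sorted(tree)
        }
    return fn(path, tree, *rests)  # leaf: path, then one leaf from each tree


tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
other = {'b': (20, [30, 40]), 'a': 10, 'c': None, 'd': 50}
# Pair each leaf's depth with the sum of the corresponding leaves
result = map_with_path(lambda p, x, y: (len(p), x + y), tree, other)
# result -> {'a': (1, 11), 'b': ((2, 22), [(3, 33), (3, 44)]), 'c': None, 'd': (1, 55)}
```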

codecov-commenter commented 1 year ago

Codecov Report

Base: 92.65% // Head: 92.93% // Increases project coverage by +0.28% :tada:

Coverage data is based on head (f233607) compared to base (582f355). Patch coverage: 90.00% of modified lines in pull request are covered.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #11      +/-   ##
==========================================
+ Coverage   92.65%   92.93%   +0.28%
==========================================
  Files           4        4
  Lines         245      269      +24
==========================================
+ Hits          227      250      +23
- Misses         18       19       +1
```

| Flag | Coverage Δ | |
|---|---|---|
| unittests | `92.93% <90.00%> (+0.28%)` | :arrow_up: |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ | |
|---|---|---|
| `optree/__init__.py` | `100.00% <ø> (ø)` | |
| `optree/registry.py` | `95.50% <ø> (ø)` | |
| `optree/ops.py` | `90.62% <90.00%> (+0.91%)` | :arrow_up: |
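The headline figures follow from the line counts above. Codecov reports coverage as hits divided by tracked lines; partial lines would also count in the denominator, but none appear here. The reported values match if the ratio is truncated, not rounded, to two decimals: 250/269 is about 92.937%, shown as 92.93%. That truncation rule is an inference from these numbers only. The check below uses integer basis points to avoid floating-point rounding.

```python
def coverage_bp(hits: int, lines: int) -> int:
    """Coverage in basis points (hundredths of a percent), truncated toward zero."""
    return hits * 10000 // lines


def fmt(bp: int) -> str:
    """Format non-negative basis points as a percentage string, e.g. 9265 -> '92.65%'."""
    return f"{bp // 100}.{bp % 100:02d}%"


base_lines, base_hits = 245, 227
head_lines, head_hits = base_lines + 24, base_hits + 23  # +24 lines, +23 hits in this PR

base = coverage_bp(base_hits, base_lines)
head = coverage_bp(head_hits, head_lines)
print(fmt(base), fmt(head), "+" + fmt(head - base))  # 92.65% 92.93% +0.28%
print(head_lines - head_hits)  # misses at head: 19
```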
