metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0
136 stars, 6 forks

feat: preserve dict key order in the output of `tree_unflatten`, `tree_map`, and `tree_map_with_path` #46

Closed XuehaiPan closed 1 year ago

XuehaiPan commented 1 year ago

Description

Add a new option to _C.flatten and _C.flatten_with_path (enabled globally) that saves the original key order in the PyTreeSpec and uses it to preserve key order during unflattening.

Now, the output of tree_map and tree_map_with_path will guarantee the same key order as the input dictionaries.

Before: output dict from tree_unflatten is in sorted order. The output key order changes even if mapped with an identity function.

>>> optree.tree_map(lambda x: x, {'b': 2, 'a': 1})
{'a': 1, 'b': 2}

After: output dict from tree_unflatten is consistent with the input dict.

>>> optree.tree_map(lambda x: x, {'b': 2, 'a': 1})
{'b': 2, 'a': 1}

Note that tree_map still maps the leaves in sorted key order (the same order as tree_flatten and tree_leaves). This PR only changes the behavior for tree_unflatten.

>>> leaves = []
>>> def add_leaves(x):
...     leaves.append(x)
...     return x
...
>>> optree.tree_map(add_leaves, {'b': 2, 'a': 1})
{'b': 2, 'a': 1}
>>> leaves
[1, 2]
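
The behavior described above can be illustrated with a minimal pure-Python sketch. It is not OpTree's actual C++ implementation: the helper names (flatten_dict, unflatten_dict, map_dict) and the dict-based spec are hypothetical, and only a single-level dict is handled. It assumes the spec records both the sorted keys (used for leaf order) and the original keys (used to rebuild the output).

```python
def flatten_dict(tree):
    """Flatten a one-level dict into leaves (sorted key order) plus a spec."""
    sorted_keys = sorted(tree)
    leaves = [tree[key] for key in sorted_keys]
    # Hypothetical spec: remember the original insertion order alongside the sorted keys.
    spec = {"sorted_keys": sorted_keys, "original_keys": list(tree)}
    return leaves, spec


def unflatten_dict(spec, leaves):
    """Rebuild the dict, emitting keys in the original insertion order."""
    by_key = dict(zip(spec["sorted_keys"], leaves))
    return {key: by_key[key] for key in spec["original_keys"]}


def map_dict(func, tree):
    """Apply func to each leaf in sorted key order, then restore key order."""
    leaves, spec = flatten_dict(tree)
    return unflatten_dict(spec, [func(leaf) for leaf in leaves])


visited = []


def record(x):
    visited.append(x)
    return x


result = map_dict(record, {"b": 2, "a": 1})
print(list(result))  # ['b', 'a']  (input key order is preserved)
print(visited)       # [1, 2]      (leaves are still visited in sorted key order)
```

The extra bookkeeping in the spec is the kind of additional work during flattening and unflattening that could account for a performance cost.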

Also, if users maintain the treespec manually, for example by 1) flattening the tree, 2) holding a reference to the treespec and doing something with the leaves, and 3) unflattening the results back into a tree with the treespec, the key order in the resulting tree is still sorted. This PR only affects the functions tree_map and tree_map_with_path.

Motivation and Context

Resolves #45

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (9e150ce) 100.00% compared to head (e09a28b) 100.00%.

Additional details and impacted files

```diff
@@            Coverage Diff            @@
##              main       #46   +/-   ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files            4         4
  Lines          424       424
=========================================
  Hits           424       424
```

| Flag | Coverage Δ | |
|---|---|---|
| unittests | `100.00% <100.00%> (ø)` | |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ | |
|---|---|---|
| optree/ops.py | `100.00% <100.00%> (ø)` | |

XuehaiPan commented 1 year ago

Got a ~25% performance drop for tree_map on builtins.dict.

XuehaiPan commented 1 year ago

> now the performance downgrade on tree_unflatten is intolerable

Code robustness is more important than performance. Without this change, tree_map does not preserve dict key order even when mapping the identity function, which is unintuitive for users who have not read the documentation carefully.

This PR adds extra operations in tree_flatten and tree_unflatten.

It will introduce some performance regression, but tree operations are sometimes not the performance bottleneck of the whole pipeline. For example, if you speed up a tree operation from 0.1ms to 0.08ms (20%) while your environment simulation takes 1ms, the total performance gain is minor: 1.1ms vs. 1.08ms (about 2%).
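
The arithmetic in this example can be checked with a short sketch. The helper name pipeline_gain is hypothetical, and the timings are the illustrative numbers from the paragraph above, not measurements.

```python
def pipeline_gain(tree_before_ms, tree_after_ms, other_ms):
    """Relative end-to-end gain when only the tree operation gets faster."""
    before = tree_before_ms + other_ms
    after = tree_after_ms + other_ms
    return (before - after) / before


# Illustrative timings from the discussion: 0.1ms -> 0.08ms tree op, 1ms simulation.
tree_gain = (0.1 - 0.08) / 0.1
total_gain = pipeline_gain(0.1, 0.08, 1.0)

print(f"tree operation gain: {tree_gain:.0%}")   # 20%
print(f"whole pipeline gain: {total_gain:.1%}")  # 1.8%
```

The same reasoning applies in reverse: a regression confined to the tree operation is diluted by the same factor when the rest of the pipeline dominates.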

Benjamin-eecs commented 1 year ago

> It would introduce some performance regression. But the tree operations sometimes are not the performance bottleneck in the whole pipeline. For example, you reduce a tree operation from 0.1ms to 0.08ms (20%), but your environment simulation takes 1ms. The total performance gain is minor (1.1ms vs. 1.08ms (2%)).

Thanks for your reply! I will try to run the benchmark with the wheel from this PR on EnvPool over the next two days and get back to you when the results are out. Please wait until the benchmark is ready.