metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

[Feature Request] Allow tree-map with mixed inputs of ordered and unordered dictionaries #27

Closed. XuehaiPan closed this issue 1 year ago.

XuehaiPan commented 1 year ago

Required prerequisites

Motivation

The tree_map function uses treespec.flatten_up_to to handle the multi-input case. The flatten_up_to method requires each additional pytree to have exactly the same node types and metadata as the treespec. For convenience, users may want to map over a dict together with an OrderedDict (or vice versa), but tree_map currently raises an error:

In [1]: from optree import *

In [2]: from collections import OrderedDict

In [3]: d = {'a': 1, 'b': 2}

In [4]: od = OrderedDict(d)

In [5]: d
Out[5]: {'a': 1, 'b': 2}

In [6]: od
Out[6]: OrderedDict([('a', 1), ('b', 2)])

In [7]: tree_map(lambda x, y: x + y, d, od)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[7], line 1
----> 1 tree_map(lambda x, y: x + y, d, od)

File ~/Projects/optree/optree/ops.py:460, in tree_map(func, tree, is_leaf, none_is_leaf, namespace, *rests)
    417 """Map a multi-input function over pytree args to produce a new pytree.
    418 
    419 See also :func:`tree_map_`, :func:`tree_map_with_path`, and :func:`tree_map_with_path_`.
   (...)
    457     is the tuple of values at corresponding nodes in ``rests``.
    458 """
    459 leaves, treespec = tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
--> 460 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    461 flat_results = map(func, *flat_args)
    462 return treespec.unflatten(flat_results)

File ~/Projects/optree/optree/ops.py:460, in <listcomp>(.0)
    417 """Map a multi-input function over pytree args to produce a new pytree.
    418 
    419 See also :func:`tree_map_`, :func:`tree_map_with_path`, and :func:`tree_map_with_path_`.
   (...)
    457     is the tuple of values at corresponding nodes in ``rests``.
    458 """
    459 leaves, treespec = tree_flatten(tree, is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
--> 460 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    461 flat_results = map(func, *flat_args)
    462 return treespec.unflatten(flat_results)

ValueError: Expected dict, got OrderedDict([('a', 1), ('b', 2)]).
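To make the failure mode concrete, here is a toy, pure-Python model of a strict flatten-up-to step for a single dict-like node. It is a sketch under stated assumptions, not optree's C++ implementation: the helper name strict_dict_flatten_up_to and its (spec_type, spec_keys, node) signature are hypothetical, and the order-sensitive key check is a simplification. The point it shows is the exact type match, which rejects an OrderedDict where the spec recorded a dict.

```python
from collections import OrderedDict
from typing import Any, List, Sequence


def strict_dict_flatten_up_to(spec_type: type, spec_keys: Sequence[Any], node: Any) -> List[Any]:
    """Hypothetical toy model of a strict flatten-up-to step for one dict-like node."""
    # Exact type match (`type(x) is T`, not `isinstance`), so an OrderedDict
    # is rejected when the spec recorded a plain dict, and vice versa.
    if type(node) is not spec_type:
        raise ValueError(f"Expected {spec_type.__name__}, got {node!r}.")
    # Order-sensitive key check: a simplification chosen for this sketch.
    if list(node) != list(spec_keys):
        raise ValueError(f"Expected keys {list(spec_keys)}, got {list(node)}.")
    # Children are returned in the spec's key order.
    return [node[k] for k in spec_keys]


d = {"a": 1, "b": 2}
od = OrderedDict(d)
print(strict_dict_flatten_up_to(dict, ["a", "b"], d))  # [1, 2]
try:
    strict_dict_flatten_up_to(dict, ["a", "b"], od)
except ValueError as exc:
    print(exc)  # an "Expected dict, got OrderedDict..." message
```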

NOTE: this feature is not backward compatible.

Solution

Reimplement the PyTreeSpec::FlattenUpTo method so that dict and OrderedDict nodes are accepted interchangeably and their keys are compared by set equality, ignoring key order.
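A minimal pure-Python sketch of the proposed relaxation for a single node might look like the following. The helper name relaxed_dict_flatten_up_to and the DICT_TYPES tuple are hypothetical illustrations, not optree API. dict and OrderedDict are treated as interchangeable, keys are compared as sets, and the children are returned in the spec's key order so they line up with the leaves of the first tree.

```python
from collections import OrderedDict
from typing import Any, List, Sequence

# Hypothetical: mapping types treated as interchangeable by the relaxed check.
DICT_TYPES = (dict, OrderedDict)


def relaxed_dict_flatten_up_to(spec_type: type, spec_keys: Sequence[Any], node: Any) -> List[Any]:
    """Sketch: accept dict/OrderedDict interchangeably and ignore key order."""
    if spec_type in DICT_TYPES:
        # Either mapping type is acceptable when the spec is dict-like.
        if type(node) not in DICT_TYPES:
            raise ValueError(f"Expected dict or OrderedDict, got {node!r}.")
    elif type(node) is not spec_type:
        raise ValueError(f"Expected {spec_type.__name__}, got {node!r}.")
    # Order-insensitive key comparison via set equality.
    if set(node) != set(spec_keys):
        raise ValueError(f"Expected keys {list(spec_keys)}, got {list(node)}.")
    # Look children up in the spec's key order, whatever the node's own order is.
    return [node[k] for k in spec_keys]


print(relaxed_dict_flatten_up_to(dict, ["a", "b"], OrderedDict([("b", 2), ("a", 1)])))  # [1, 2]
```

Building two sets per node on every call is the main extra cost of this approach compared with a strict check.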

Alternatives

No response

Additional context

No response

XuehaiPan commented 1 year ago

After some benchmarking (3-way multi-treemap in benchmark.py), the set-based approach showed a ~45% performance regression and the sort-based approach a ~25% regression.
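For readers who want to compare the two key-equality strategies themselves, here is a hypothetical pure-Python micro-benchmark. It is not the benchmark.py from the repository. Because it measures Python-level set and sorted calls rather than the C++ FlattenUpTo path, its timings are not expected to reproduce the percentages above. The dictionary size and iteration count are arbitrary assumptions.

```python
import timeit
from collections import OrderedDict


def keys_equal_set(a: dict, b: dict) -> bool:
    # Set-based: build a set of keys from each side and compare (order-insensitive).
    return set(a) == set(b)


def keys_equal_sorted(a: dict, b: dict) -> bool:
    # Sort-based: sort both key lists and compare; keys must be mutually orderable.
    return sorted(a) == sorted(b)


def bench(n_keys: int = 8, number: int = 100_000) -> dict:
    a = {f"k{i}": i for i in range(n_keys)}
    b = OrderedDict(reversed(list(a.items())))  # same keys, reversed order
    return {
        "set": timeit.timeit(lambda: keys_equal_set(a, b), number=number),
        "sorted": timeit.timeit(lambda: keys_equal_sorted(a, b), number=number),
    }


if __name__ == "__main__":
    for name, seconds in bench().items():
        print(f"{name:>6}: {seconds:.3f} s")
```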

XuehaiPan commented 1 year ago

Instead of sorting the keys, we can check key equality by first comparing the sizes of the two dicts, then iterating over the keys of one and calling PyDict_Contains on the other.
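A Python-level analogue of the membership-based check is sketched below. The function name keys_equal_by_contains is hypothetical, and the `key in b` lookup stands in for a C-level PyDict_Contains call. The size check is what makes one-directional membership sufficient: if both mappings have the same number of keys and every key of `a` is in `b`, the two key sets are equal.

```python
from collections import OrderedDict
from typing import Any, Mapping


def keys_equal_by_contains(a: Mapping[Any, Any], b: Mapping[Any, Any]) -> bool:
    """Order-insensitive key equality without building sets or sorting."""
    # Equal sizes plus "every key of a is in b" implies identical key sets.
    if len(a) != len(b):
        return False
    # Each membership test is a hash lookup (average O(1)), like PyDict_Contains.
    return all(key in b for key in a)


print(keys_equal_by_contains({"a": 1, "b": 2}, OrderedDict([("b", 2), ("a", 1)])))  # True
print(keys_equal_by_contains({"a": 1, "c": 3}, {"a": 1, "b": 2}))  # False
```

Compared with the set-based approach, this avoids allocating temporary sets. Compared with the sort-based approach, it is average O(n) instead of O(n log n) and does not require the keys to be mutually orderable.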