Required prerequisites
[X] I have searched the Issue Tracker and confirmed that this hasn't already been reported. (Comment there if it has.)
Motivation
Currently, tree_map is implemented by simply flattening the tree, mapping over the leaves, and unflattening back to the original tree structure. We intentionally sort the dict keys to keep the guarantee that "equal inputs imply equal leaves and equal structure".
In [1]: import optree
In [2]: d1 = {'a': 1, 'b': 2}
In [3]: d2 = {'b': 2, 'a': 1}
In [4]: d1 == d2
Out[4]: True
In [5]: optree.tree_flatten(d1)
Out[5]: ([1, 2], PyTreeSpec({'a': *, 'b': *}))
In [6]: optree.tree_flatten(d2)
Out[6]: ([1, 2], PyTreeSpec({'a': *, 'b': *}))
In [7]: from collections import OrderedDict
In [8]: od1 = OrderedDict([('a', 1), ('b', 2)])
In [9]: od2 = OrderedDict([('b', 2), ('a', 1)])
In [10]: od1 == od2
Out[10]: False
In [11]: optree.tree_flatten(od1)
Out[11]: ([1, 2], PyTreeSpec(OrderedDict([('a', *), ('b', *)])))
In [12]: optree.tree_flatten(od2)
Out[12]: ([2, 1], PyTreeSpec(OrderedDict([('b', *), ('a', *)])))
However, this guarantee may be too strict, because the same property does not hold when you iterate over the keys: two equal dicts can yield their keys in different orders.
In [13]: d1.keys() == d2.keys()
Out[13]: True
In [14]: list(d1.keys()) == list(d2.keys())
Out[14]: False
In [15]: od1.keys() == od2.keys()
Out[15]: True
In [16]: list(od1.keys()) == list(od2.keys())
Out[16]: False
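Also, we have seen many use cases that use tree_map to process function inputs. There, the tree_map function sorts the keyword arguments, so their order changes. This may cause bugs, because many modules now rely heavily on builtins.dict being guaranteed to preserve insertion order since Python 3.7.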
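To illustrate the concern about function inputs, here is a minimal pure-Python sketch. It does not call optree: flatten_sorted and map_kwargs_sorted are hypothetical stand-ins that only mimic the sorted-key flatten/unflatten round trip described above, and configure is a made-up function whose keyword arguments get processed.

```python
def flatten_sorted(d):
    """Flatten a one-level dict with sorted keys (hypothetical stand-in, not optree)."""
    keys = sorted(d)
    return [d[k] for k in keys], keys


def map_kwargs_sorted(func, kwargs):
    """Flatten, map over the leaves, and rebuild the dict in sorted-key order."""
    leaves, keys = flatten_sorted(kwargs)
    return dict(zip(keys, map(func, leaves)))


def configure(**kwargs):
    """Made-up function: process keyword arguments, then report their order."""
    processed = map_kwargs_sorted(lambda x: x * 2, kwargs)
    return list(processed)


# The caller passed lr first, but the processed kwargs come back sorted.
order = configure(lr=0.1, batch_size=32)
print(order)  # ['batch_size', 'lr']
```

Any downstream code that depends on the order in which the caller wrote the keyword arguments would observe the sorted order instead.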
Solution
Add an internal flatten function that preserves the dict order. Since the intermediate variable treespec is never returned, it would be fine to leave the PyTreeSpec methods, e.g., __eq__, __ne__, and entries, unchanged.
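One way to picture this proposal is the pure-Python sketch below. It is an assumption-laden illustration, not optree's implementation: _flatten_keep_order and tree_map_keep_order are hypothetical names, only plain dict, list, and tuple containers are handled, and the "treespec" is modeled as a rebuild closure that never leaves the function.

```python
def _flatten_keep_order(tree, leaves):
    """Collect leaves into `leaves` and return a closure that rebuilds the tree.

    Hypothetical helper: dict keys are visited in insertion order, not sorted.
    """
    if type(tree) is dict:
        keys = list(tree)  # insertion order
        children = [_flatten_keep_order(tree[k], leaves) for k in keys]
        return lambda it: {k: child(it) for k, child in zip(keys, children)}
    if type(tree) in (list, tuple):
        kind = type(tree)
        children = [_flatten_keep_order(v, leaves) for v in tree]
        return lambda it: kind(child(it) for child in children)
    # Anything else is treated as a leaf in this sketch.
    leaves.append(tree)
    return lambda it: next(it)


def tree_map_keep_order(func, tree):
    """Map `func` over the leaves; the intermediate structure is never returned."""
    leaves = []
    rebuild = _flatten_keep_order(tree, leaves)
    return rebuild(iter([func(x) for x in leaves]))


result = tree_map_keep_order(lambda x: x + 1, {'b': 2, 'a': {'d': [1, 2], 'c': 3}})
print(list(result))       # ['b', 'a']
print(list(result['a']))  # ['d', 'c']
```

Because the intermediate structure stays private to the mapping function, the public flatten behavior and the comparison semantics of the returned treespec would not need to change under this approach.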
Alternatives
No response
Additional context
No response