metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

[Feature Request] Preserve `dict` key order in the output of `tree_unflatten`, `tree_map`, and `tree_map_with_path` #45

Closed XuehaiPan closed 1 year ago

XuehaiPan commented 1 year ago

Required prerequisites

Motivation

Currently, tree_map simply flattens the tree, maps over the leaves, and unflattens the result back into the original tree structure. During flattening, we intentionally sort the dict keys to preserve the invariant "equal inputs imply equal leaves and equal structure".

In [1]: import optree

In [2]: d1 = {'a': 1, 'b': 2}

In [3]: d2 = {'b': 2, 'a': 1}

In [4]: d1 == d2
Out[4]: True

In [5]: optree.tree_flatten(d1)
Out[5]: ([1, 2], PyTreeSpec({'a': *, 'b': *}))

In [6]: optree.tree_flatten(d2)
Out[6]: ([1, 2], PyTreeSpec({'a': *, 'b': *}))

In [7]: from collections import OrderedDict

In [8]: od1 = OrderedDict([('a', 1), ('b', 2)])

In [9]: od2 = OrderedDict([('b', 2), ('a', 1)])

In [10]: od1 == od2
Out[10]: False

In [11]: optree.tree_flatten(od1)
Out[11]: ([1, 2], PyTreeSpec(OrderedDict([('a', *), ('b', *)])))

In [12]: optree.tree_flatten(od2)
Out[12]: ([2, 1], PyTreeSpec(OrderedDict([('b', *), ('a', *)])))
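The flatten, map, unflatten pipeline described above can be sketched in pure Python. This is only a toy model for illustration, not optree's implementation: the names toy_flatten, toy_unflatten, toy_tree_map, and LEAF are hypothetical, and only dict, OrderedDict, list, and tuple containers are handled.

```python
from collections import OrderedDict

LEAF = object()  # marks a leaf position in the toy spec (hypothetical helper)


def toy_flatten(tree):
    """Flatten nested dicts/lists/tuples into (leaves, spec)."""
    if isinstance(tree, OrderedDict):
        kind, keys = OrderedDict, list(tree)  # insertion order is kept
    elif isinstance(tree, dict):
        kind, keys = dict, sorted(tree)  # plain dict: keys are sorted
    elif isinstance(tree, (list, tuple)):
        kind, keys = type(tree), list(range(len(tree)))
    else:
        return [tree], LEAF  # anything else is treated as a leaf
    leaves, child_specs = [], []
    for key in keys:
        sub_leaves, sub_spec = toy_flatten(tree[key])
        leaves.extend(sub_leaves)
        child_specs.append(sub_spec)
    return leaves, (kind, keys, child_specs)


def toy_unflatten(spec, leaves):
    """Rebuild a tree from a toy spec and a flat sequence of leaves."""
    leaf_iter = iter(leaves)

    def build(node):
        if node is LEAF:
            return next(leaf_iter)
        kind, keys, child_specs = node
        children = [build(child) for child in child_specs]
        if kind in (dict, OrderedDict):
            return kind(zip(keys, children))  # keys come back in spec order
        return kind(children)

    return build(spec)


def toy_tree_map(func, tree):
    """Flatten, apply func to every leaf, then unflatten."""
    leaves, spec = toy_flatten(tree)
    return toy_unflatten(spec, [func(leaf) for leaf in leaves])


d1 = {'a': 1, 'b': 2}
d2 = {'b': 2, 'a': 1}
# Equal plain dicts give equal leaves and equal specs in this toy model.
same = toy_flatten(d1) == toy_flatten(d2)
# The rebuilt dict follows the sorted key order stored in the spec.
mapped = toy_tree_map(lambda x: x, {'dtype': 'float32', 'device': 'cpu'})
```

In this sketch the key order of the output dict comes entirely from the spec, which is why sorting at flatten time also reorders the result of the map.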

However, this may be too strict, because the same invariant does not hold when you iterate over the keys: equal dicts can yield their keys in different orders.

In [13]: d1.keys() == d2.keys()
Out[13]: True

In [14]: list(d1.keys()) == list(d2.keys())
Out[14]: False

In [15]: od1.keys() == od2.keys()
Out[15]: True

In [16]: list(od1.keys()) == list(od2.keys())
Out[16]: False

Also, we have seen many use cases that apply tree_map to function inputs, such as:

import optree

def func(*args, **kwargs):
    # do_something is a placeholder for the function applied to each leaf
    args, kwargs = optree.tree_map(do_something, (args, kwargs))
    ...

The tree_map function then returns the keyword arguments in sorted key order rather than in the order they were passed:

In [1]: import optree

In [2]: import torch

In [3]: args = (torch.zeros(()), torch.ones(()))

In [4]: kwargs = {'dtype': torch.float32, 'device': torch.device('cpu')}

In [5]: (args, kwargs)
Out[5]: ((tensor(0.), tensor(1.)), {'dtype': torch.float32, 'device': device(type='cpu')})

In [6]: optree.tree_map(lambda x: x, (args, kwargs))
Out[6]: ((tensor(0.), tensor(1.)), {'device': device(type='cpu'), 'dtype': torch.float32})

This may cause subtle bugs, because much code now relies on builtins.dict being guaranteed to preserve insertion order since Python 3.7.
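As a small illustration of the kind of order dependence meant here, the sketch below builds a string from keyword arguments. The helpers describe_call and sort_keys are hypothetical; sort_keys merely stands in for the key sorting that tree_map performs on plain dicts.

```python
def describe_call(*args, **kwargs):
    """Render the arguments in the order they were received."""
    parts = [repr(arg) for arg in args]
    parts += [f"{key}={value!r}" for key, value in kwargs.items()]
    return ", ".join(parts)


def sort_keys(mapping):
    """Return a copy of the dict with its keys in sorted order."""
    return {key: mapping[key] for key in sorted(mapping)}


kwargs = {'dtype': 'float32', 'device': 'cpu'}

before = describe_call(**kwargs)            # "dtype='float32', device='cpu'"
after = describe_call(**sort_keys(kwargs))  # "device='cpu', dtype='float32'"

# The two dicts compare equal, yet the order-sensitive output differs.
still_equal = kwargs == sort_keys(kwargs)
outputs_differ = before != after
```

Anything derived from iteration order, such as a log line, a cache key, or positional unpacking of kwargs.values(), can change in the same way.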

Solution

Add an internal flatten function that preserves the dict key order. Because the intermediate treespec is never returned to the caller, PyTreeSpec methods and attributes such as __eq__, __ne__, and entries can stay unchanged.
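One possible shape for such an internal helper is sketched below in pure Python. It is an assumption about how the idea could work, not the actual optree change: flatten_keep_order and tree_map_keep_order are hypothetical names, and only plain dict, list, and tuple containers are handled. Leaves are still visited in sorted key order, while the rebuilt dicts get their original insertion order back.

```python
def flatten_keep_order(tree):
    """Return (leaves, rebuild) for nested plain dicts, lists, and tuples.

    Leaves are collected in sorted key order; rebuild() restores each
    dict's original insertion order.
    """
    if isinstance(tree, dict):
        insertion_keys = list(tree)  # remembered for rebuilding
        visit_keys = sorted(tree)    # leaf traversal order stays sorted
    elif isinstance(tree, (list, tuple)):
        insertion_keys = visit_keys = list(range(len(tree)))
    else:
        return [tree], lambda leaf_iter: next(leaf_iter)

    leaves, rebuilders = [], {}
    for key in visit_keys:
        sub_leaves, sub_rebuild = flatten_keep_order(tree[key])
        leaves.extend(sub_leaves)
        rebuilders[key] = sub_rebuild

    def rebuild(leaf_iter):
        # Consume the leaves in the same order they were produced.
        children = {key: rebuilders[key](leaf_iter) for key in visit_keys}
        if isinstance(tree, dict):
            return {key: children[key] for key in insertion_keys}
        return type(tree)(children[key] for key in insertion_keys)

    return leaves, rebuild


def tree_map_keep_order(func, tree):
    """Map func over the leaves; output dicts keep the input key order."""
    leaves, rebuild = flatten_keep_order(tree)
    return rebuild(iter([func(leaf) for leaf in leaves]))


args = (0.0, 1.0)
kwargs = {'dtype': 'float32', 'device': 'cpu'}
new_args, new_kwargs = tree_map_keep_order(lambda x: x, (args, kwargs))
# list(new_kwargs) is ['dtype', 'device'], matching the input order.
```

Because the order-preserving information lives only inside the helper, a public flatten that sorts keys could keep its existing behavior alongside it.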

Alternatives

No response

Additional context

No response