Conchylicultor opened 4 years ago
Hmmm... Looks like ordereddict and defaultdict keep keys in order:
https://github.com/google/jax/blob/bf041fbdb16cb1360d3b914ab44d8c3a799d566b/jax/tree_util.py#L247-L255
while standard dicts alphabetize (sort) their keys:
https://github.com/google/jax/blob/c7aff1da06072db8fb074f09a8215615d607adc2/jaxlib/pytree.cc#L132-L134
That is surprising behavior: it would be nice to make this more consistent.
I agree, we should not sort dictionary keys.
To support Python 3.6 we need to sort keys for deterministic traversal order on equal but distinct dict objects, right?
That is, I think the OP may have been mistaken, according to the Python docs:
Changed in version 3.7: Dictionary order is guaranteed to be insertion order. This behavior was an implementation detail of CPython from 3.6.
> To support Python 3.6 we need to sort keys for deterministic traversal order on equal but distinct dict objects, right?
This is still the case even in Python 3.7+. Dictionaries and dictionary keys preserve order, but compare equal if they have the same elements regardless of order:
In [14]: x = {1: 1, 2: 2}
In [15]: y = {2: 2, 1: 1}
In [16]: x
Out[16]: {1: 1, 2: 2}
In [17]: y
Out[17]: {2: 2, 1: 1}
In [18]: x == y
Out[18]: True
In [19]: x.keys() == y.keys()
Out[19]: True
In [20]: x.keys()
Out[20]: dict_keys([1, 2])
In [21]: y.keys()
Out[21]: dict_keys([2, 1])
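To make the point concrete, here is a minimal pure-Python sketch of a one-level dict flatten. The helper names `flatten_insertion_order` and `flatten_sorted` are hypothetical, not JAX APIs. With insertion order, two equal dicts produce different leaves and structures; sorting the keys first makes the result depend only on the dict's contents.

```python
def flatten_insertion_order(d):
    """Hypothetical one-level flatten that follows insertion order."""
    keys = list(d)
    return [d[k] for k in keys], tuple(keys)


def flatten_sorted(d):
    """Hypothetical one-level flatten that sorts keys first (keys must be sortable)."""
    keys = sorted(d)
    return [d[k] for k in keys], tuple(keys)


x = {1: 'one', 2: 'two'}
y = {2: 'two', 1: 'one'}

assert x == y  # equal dicts...

# ...but insertion-order traversal disagrees between them:
print(flatten_insertion_order(x))  # (['one', 'two'], (1, 2))
print(flatten_insertion_order(y))  # (['two', 'one'], (2, 1))

# Sorted traversal is deterministic for equal dicts:
assert flatten_sorted(x) == flatten_sorted(y)
```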
I believe JAX should behave like dm-tree, where flattening any dict sorts the keys, but packing a dict restores the original key order.
import tree
x = {'z': 'z', 'a': 'a'}
print(tree.flatten(x)) # Keys sorted: ['a', 'z']
print(tree.unflatten_as(x, [0, 1])) # Key order restored: {'z': 1, 'a': 0}
This allows all dict types to have the same flattened representation, so they can be mixed together:
import collections
import jax
d0 = {'z': 'z', 'a': 'a'}
d1 = collections.defaultdict(int, d0)
assert jax.tree_util.tree_leaves(d0) == jax.tree_util.tree_leaves(d1)  # AssertionError when this was written: ['a', 'z'] != ['z', 'a']
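The dm-tree behavior described above (sort on flatten, restore order and type on unflatten) can be sketched in plain Python for a single level of nesting. `flatten_dict` and `unflatten_like` are hypothetical helpers, not dm-tree or JAX functions; the `defaultdict` branch is there because its constructor takes `default_factory` first.

```python
import collections


def flatten_dict(d):
    """Hypothetical: leaves in sorted-key order, for any dict type."""
    return [d[k] for k in sorted(d)]


def unflatten_like(template, leaves):
    """Hypothetical: rebuild `template`'s type, using `template`'s key order."""
    by_key = dict(zip(sorted(template), leaves))
    rebuilt = [(k, by_key[k]) for k in template]  # original insertion order
    if isinstance(template, collections.defaultdict):
        return type(template)(template.default_factory, rebuilt)
    return type(template)(rebuilt)


d0 = {'z': 'z', 'a': 'a'}
d1 = collections.defaultdict(int, d0)
d2 = collections.OrderedDict(d0)

# Every dict flavor flattens the same way...
assert flatten_dict(d0) == flatten_dict(d1) == flatten_dict(d2) == ['a', 'z']

# ...but unflattening keeps each one's own type and key order.
r1 = unflatten_like(d1, [0, 1])
print(type(r1).__name__, list(r1.items()))  # defaultdict [('z', 1), ('a', 0)]
```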
@Conchylicultor good point! That sounds plausible.
@mattjj Would it be a good idea to register dict as a node, similar to what is done for ordereddict and defaultdict?
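from jax.tree_util import register_pytree_node
from jax._src.util import safe_zip  # JAX-internal helper
# Proposal: dict is already a built-in pytree node, so this only works
# once the kDict special case is removed from jaxlib (see below).
register_pytree_node(
    dict,
    lambda x: (tuple(x.values()), {key: None for key in x}.keys()),
    lambda keys, values: dict(safe_zip(keys, values))
)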
Of course, things related to kDict should also be removed from jaxlib/pytree.{h,cc}.
Adding my support for maintaining dict key ordering over flattening operations.
A related question: for a dictionary d, does tuple(d.values()) internally use tree_flatten? Because that operation also does not maintain key ordering when building the tuple.
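For reference, `tuple(d.values())` is ordinary Python and is evaluated by the dict itself, not by `tree_flatten`. A quick side-by-side check of the two orders, assuming a JAX version in which dict keys are still sorted on flattening:

```python
import jax

d = {'b': 2, 'a': 1}

# Plain Python: dict views follow insertion order.
print(tuple(d.values()))             # (2, 1)

# jax.tree_util sorts dict keys when flattening.
print(jax.tree_util.tree_leaves(d))  # [1, 2]
```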
Now that support for Python 3.6 has been dropped, we can probably revisit this.
I believe this is an important bug to fix. Have we figured out why this is happening (or could someone point me to the relevant source code of the C++ implementation so I can dig into it)? I think the behavior should be consistent regardless of Python 3.6, but it'd be great to revisit this.
@wookayin - it's happening because that's how tree flattening of dicts is implemented. The line where the sort is taking place is here: https://github.com/tensorflow/tensorflow/blob/eb8425f115e5a93274f709cdfaf254798f9aa4c7/tensorflow/compiler/xla/python/pytree.cc#L167
The problem is, "fixing" this is not as easy as just removing that sort. There are many parts of JAX that rely on equality comparisons of the flattened representation of dicts, and if you preserve insertion order in flattening, then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2), which has deep and subtle implications in the implementation of JAX transforms throughout the package.
For that reason, it's not clear to me whether this should be considered a bug, or just the way that flattening works in JAX (and it's why nobody as of yet has been eager to attempt making this change).
@jakevdp I don't understand your argument. This was already resolved in https://github.com/google/jax/issues/4085#issuecomment-675322841
Flattening would still be sorted, so if d1 == d2, then tree_flatten(d1) == tree_flatten(d2), irrespective of the key order of d1 and d2.
However, the key order would be restored during packing:
import tree  # dm-tree
x = {'z': 'z', 'a': 'a'}
print(tree.flatten(x))  # Keys sorted: ['a', 'z']
print(tree.unflatten_as(x, [0, 1]))  # Key order restored: {'z': 1, 'a': 0}
So all dict types (OrderedDict, ...) would have the same flattened representation, but would still preserve their key order when unflattened.
> then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2)
This is exactly the problem with the current Jax implementation:
import collections
import jax
import tree  # dm-tree
d0 = {'z': 'z', 'a': 'a'}
d1 = collections.OrderedDict(d0)
assert d0 == d1  # Works
assert jax.tree_util.tree_leaves(d0) == jax.tree_util.tree_leaves(d1)  # << AssertionError: ['a', 'z'] != ['z', 'a']
# By comparison, DeepMind `tree` / tf.nest works as expected:
assert tree.flatten(d0) == tree.flatten(d1)  # Works: ['a', 'z'] == ['a', 'z']
That makes sense, thanks. I'd missed that comment from a few years ago.
It seems that there's broad agreement here that this should be fixed – we just need someone to take on the project.
Could we add a flag (e.g. a global variable) to let the user decide whether to sort the keys or not? For example:
jax.tree_util.dict_key_sorted(True) # default behavior
jax.tree_util.dict_key_sorted(False)
In this issue, all the keys are strings, which are sortable. There is another issue about dict key sorting, #11871. For a general PyTree, the keys are not always comparable:
tree = {1: '1', 'a': 'a'}
sorted(tree) # <- TypeError: '<' not supported between instances of 'str' and 'int'
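One way to still get a deterministic order for mixed-type keys is a tiered sort. The sketch below is hypothetical (not how JAX currently behaves): try a plain sort, then sort by (qualified type name, key) so keys are grouped by type before being compared, and only fall back to insertion order if even that fails.

```python
def sorted_keys_with_fallback(d):
    """Hypothetical: a deterministic key order for dicts with mixed-type keys."""
    keys = list(d)
    # 1. Plain sort: works whenever all keys are mutually comparable.
    try:
        return sorted(keys)
    except TypeError:
        pass
    # 2. Group by fully qualified type name, then compare within a type.
    try:
        return sorted(
            keys,
            key=lambda k: (f'{type(k).__module__}.{type(k).__qualname__}', k),
        )
    except TypeError:
        # 3. Keys of the same type are unorderable (e.g. complex numbers).
        return keys


tree = {1: '1', 'a': 'a'}
print(sorted_keys_with_fallback(tree))  # [1, 'a']
```

The trade-off is that the last tier reintroduces insertion-order dependence, so two equal dicts with unorderable keys could still flatten differently.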
> Could we add a flag (e.g. a global variable) to let the user decide whether to sort the keys or not? For example:
> jax.tree_util.dict_key_sorted(True)  # default behavior
> jax.tree_util.dict_key_sorted(False)

(This breaks referential transparency.)

> The problem is, "fixing" this is not as easy as just removing that sort. There are many parts of JAX that rely on equality comparisons of the flattened representation of dicts, and if you preserve insertion order in flattening, then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2), which has deep and subtle implications in the implementation of JAX transforms throughout the package. For that reason, it's not clear to me whether this should be considered a bug, or just the way that flattening works in JAX (and it's why nobody as of yet has been eager to attempt making this change).
I agree, referential transparency should be a key feature of the pytree utilities: equal inputs imply equal outputs. However, the current implementation always sorts the keys and returns a new, sorted dict after unflattening. Nowadays, a lot of Python code relies on builtins.dict guaranteeing insertion order. This may cause subtle bugs, because many people are not aware of this behavior of JAX pytrees (keys are sorted after tree_unflatten).
import jax
d = {'b': 2, 'a': 1}
# Map with the identity function changes the key order
out = jax.tree_util.tree_map(lambda x: x, d)  # => {'a': 1, 'b': 2}
d == out  # => True
list(d) == list(out)  # => False: ['b', 'a'] != ['a', 'b']
For example, using tree_map to process kwargs (PEP 468 – Preserving the order of **kwargs in a function):
def func(*args, **kwargs):
args, kwargs = jax.tree_util.tree_map(do_something, (args, kwargs)) # changes key order in kwargs
...
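Until flattening preserves order, one workaround for this case is to restore the caller's keyword order on the top-level kwargs dict after mapping. `map_preserving_kwargs_order` is a hypothetical helper, and `lambda x: x * 2` stands in for `do_something`; note that dicts nested inside the values still come back key-sorted.

```python
import jax


def map_preserving_kwargs_order(fn, args, kwargs):
    """Hypothetical helper: tree_map over (args, kwargs), then rebuild the
    top-level kwargs dict in the caller's original keyword order (PEP 468)."""
    new_args, new_kwargs = jax.tree_util.tree_map(fn, (args, kwargs))
    return new_args, {k: new_kwargs[k] for k in kwargs}


def func(*args, **kwargs):
    args, kwargs = map_preserving_kwargs_order(lambda x: x * 2, args, kwargs)
    return args, kwargs


args, kwargs = func(1, b=2, a=3)
print(args, kwargs)  # (2,) {'b': 4, 'a': 6}
```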
In [1]: import jax
In [2]: from typing import NamedTuple
In [3]: class Ints(NamedTuple):
...: foo: int
...: bar: int
...:
In [4]: Ints(1, 2)
Out[4]: Ints(foo=1, bar=2)
In [5]: Ints(1, 2).foo
Out[5]: 1
In [6]: Ints.__annotations__
Out[6]: {'foo': <class 'int'>, 'bar': <class 'int'>}
In [7]: Floats = NamedTuple('Floats', **jax.tree_util.tree_map(lambda ann: float, Ints.__annotations__))
In [8]: Floats(1.0, 2.0)
Out[8]: Floats(bar=1.0, foo=2.0)
In [9]: Floats(1.0, 2.0).foo
Out[9]: 2.0
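For this particular case, a possible workaround (assuming the goal is just to keep the declared field order) is to index the mapped dict by `Ints._fields` and use the list-of-pairs form of `typing.NamedTuple`, so the result no longer depends on the flattened key order.

```python
from typing import NamedTuple

import jax


class Ints(NamedTuple):
    foo: int
    bar: int


# tree_map returns the annotations dict with its keys sorted...
mapped = jax.tree_util.tree_map(lambda ann: float, Ints.__annotations__)

# ...so rebuild the field list in the original declaration order instead.
Floats = NamedTuple('Floats', [(name, mapped[name]) for name in Ints._fields])

print(Floats(1.0, 2.0))  # Floats(foo=1.0, bar=2.0)
```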
One solution is to store the input dict keys in insertion order in Node during flatten, and update the PyTreeDef.unflatten method to respect the key order while reconstructing the output pytree.
import jax
leaves, treedef = jax.tree_util.tree_flatten({'b': 2, 'a': 1})
leaves  # [1, 2]
treedef  # PyTreeDef({'a': *, 'b': *})
treedef.unflatten([11, 22])  # proposed: {'b': 22, 'a': 11}, respecting the original key order
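The proposal can be prototyped without touching jaxlib by registering a custom node for a dict subclass. In this sketch, `OrderedKeysDict` and `KeyOrder` are hypothetical names. Leaves are emitted in sorted-key order, and the aux data remembers the insertion order but compares and hashes only on the sorted keys, so equal dicts still get equal treedefs while `unflatten` restores the original order.

```python
import jax
from jax.tree_util import register_pytree_node


class KeyOrder:
    """Hypothetical aux data: equality/hash on sorted keys, remembers insertion order."""

    def __init__(self, keys):
        self.insertion_order = tuple(keys)
        self.sorted_keys = tuple(sorted(self.insertion_order))

    def __eq__(self, other):
        return isinstance(other, KeyOrder) and self.sorted_keys == other.sorted_keys

    def __hash__(self):
        return hash(self.sorted_keys)

    def __repr__(self):
        return f'KeyOrder({list(self.insertion_order)})'


class OrderedKeysDict(dict):
    """Hypothetical dict subclass used to prototype the proposed behavior."""


def _flatten(d):
    order = KeyOrder(d.keys())
    # Leaves in sorted-key order, like built-in dicts today.
    return [d[k] for k in order.sorted_keys], order


def _unflatten(order, leaves):
    by_key = dict(zip(order.sorted_keys, leaves))
    # Rebuild in the original insertion order.
    return OrderedKeysDict((k, by_key[k]) for k in order.insertion_order)


register_pytree_node(OrderedKeysDict, _flatten, _unflatten)

leaves, treedef = jax.tree_util.tree_flatten(OrderedKeysDict(b=2, a=1))
print(leaves)                             # [1, 2]
print(dict(treedef.unflatten([11, 22])))  # {'b': 22, 'a': 11}
```

One caveat: because equal-but-differently-ordered inputs now share a treedef, anything that caches on the treedef (for example `jax.jit`) may reuse whichever key order it saw first, which is one of the subtleties raised earlier in the thread.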
Ref:
Commenting to add that I just encountered this behavior and I find it quite annoying.
If I were to implement this, I'd use as the treedef a dictionary with the same keys but filled with None values. This way the implementation would completely piggyback on Python and remain consistent under all circumstances.
Any update on this? I have a similar issue with jax.jit:
$ python3 -c "import jax; print(jax.jit(lambda: {'b': None, 'a': None})())"
{'a': None, 'b': None}
With Python 3.7+ (and CPython 3.6), dicts are guaranteed to keep insertion order, similarly to OrderedDict. Both DeepMind tree and tf.nest keep dict order, but jax.tree_util does not. The fact that dict and OrderedDict behave differently, when dict is guaranteed to keep insertion order, feels inconsistent.