jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.tree_utils do not keep dict key order #4085

Open Conchylicultor opened 4 years ago

Conchylicultor commented 4 years ago

With Python 3.6+, dicts are guaranteed to keep insertion order, similar to OrderedDict.

Both deepmind tree and tf.nest keep dict order, but jax.tree_util does not.

import tensorflow as tf
import tree
import jax

data = {'z': None, 'a': None}

print(tf.nest.map_structure(lambda _: None, data))  # {'z': None, 'a': None}
print(tree.map_structure(lambda _: None, data))     # {'z': None, 'a': None}
print(jax.tree_util.tree_map(lambda _: None, data))  # {'a': None, 'z': None}  << Oops, key order inverted

The fact that dict and OrderedDict behave differently, even though dict is guaranteed to keep insertion order, feels inconsistent.

jakevdp commented 4 years ago

Hmmm... Looks like ordereddict and defaultdict keep keys in order: https://github.com/google/jax/blob/bf041fbdb16cb1360d3b914ab44d8c3a799d566b/jax/tree_util.py#L247-L255

While standard dicts alphabetize their keys: https://github.com/google/jax/blob/c7aff1da06072db8fb074f09a8215615d607adc2/jaxlib/pytree.cc#L132-L134

That is surprising behavior: it would be nice to make this more consistent.

shoyer commented 4 years ago

I agree, we should not sort dictionary keys.

mattjj commented 4 years ago

To support Python 3.6 we need to sort keys for deterministic traversal order on equal but distinct dict objects, right?

mattjj commented 4 years ago

That is, I think the OP may have been mistaken, according to the Python docs:

Changed in version 3.7: Dictionary order is guaranteed to be insertion order. This behavior was an implementation detail of CPython from 3.6.

shoyer commented 4 years ago

> To support Python 3.6 we need to sort keys for deterministic traversal order on equal but distinct dict objects, right?

This is still the case even in Python 3.7+. Dictionaries and dictionary keys preserve order, but compare equal if they have the same elements regardless of order:

In [14]: x = {1: 1, 2: 2}

In [15]: y = {2: 2, 1: 1}

In [16]: x
Out[16]: {1: 1, 2: 2}

In [17]: y
Out[17]: {2: 2, 1: 1}

In [18]: x == y
Out[18]: True

In [19]: x.keys() == y.keys()
Out[19]: True

In [20]: x.keys()
Out[20]: dict_keys([1, 2])

In [21]: y.keys()
Out[21]: dict_keys([2, 1])
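A minimal pure-Python sketch of this point. `flatten_insertion_order` and `flatten_sorted` are hypothetical stand-ins for a pytree flattener, not JAX functions. Following insertion order, two equal dicts flatten to different leaf sequences; sorting the keys restores the invariant that equal dicts flatten identically.

```python
def flatten_insertion_order(d):
    """Return (keys, leaves) following the dict's insertion order."""
    keys = list(d)
    return keys, [d[k] for k in keys]


def flatten_sorted(d):
    """Return (keys, leaves) with sorted keys, so equal dicts flatten identically."""
    keys = sorted(d)
    return keys, [d[k] for k in keys]


x = {1: 'one', 2: 'two'}
y = {2: 'two', 1: 'one'}

assert x == y  # dict equality ignores insertion order
print(flatten_insertion_order(x))  # ([1, 2], ['one', 'two'])
print(flatten_insertion_order(y))  # ([2, 1], ['two', 'one'])
print(flatten_sorted(x) == flatten_sorted(y))  # True
```

The price of sorting is that every dict's keys must be mutually comparable.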

Conchylicultor commented 4 years ago

I believe JAX should behave like dm-tree, where flattening any dict sorts the keys, but packing a dict restores the original key order.

import tree

x = {'z': 'z', 'a': 'a'}

print(tree.flatten(x))  # Keys sorted: ['a', 'z']
print(tree.unflatten_as(x, [0, 1]))  # Key order restored: {'z': 1, 'a': 0}

This allows all dict types to have the same flattened representation, so they can be mixed together:

import collections
import jax

d0 = {'z': 'z', 'a': 'a'}
d1 = collections.defaultdict(int, d0)

assert jax.tree_util.tree_leaves(d0) == jax.tree_util.tree_leaves(d1)  # AssertionError: Oops ['a', 'z'] != ['z', 'a']
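The dm-tree behavior described above (sort keys when flattening, restore the reference structure's key order when unflattening) can be sketched in pure Python for a single flat dict. `flatten` and `unflatten_as` here are hypothetical helpers that only mimic the dm-tree functions of the same names; they do not handle nesting, and `unflatten_as` always returns a plain dict.

```python
from collections import OrderedDict, defaultdict


def flatten(d):
    """Leaves in sorted-key order, regardless of the dict's insertion order."""
    return [d[k] for k in sorted(d)]


def unflatten_as(structure, leaves):
    """Rebuild a dict shaped like `structure`, keeping its original key order."""
    keys = sorted(structure)
    if len(keys) != len(leaves):
        raise ValueError(f"expected {len(keys)} leaves, got {len(leaves)}")
    by_key = dict(zip(keys, leaves))
    # Iterate over `structure` itself so its insertion order is preserved.
    return {k: by_key[k] for k in structure}


d0 = {'z': 'z', 'a': 'a'}
d1 = defaultdict(int, d0)
d2 = OrderedDict(d0)

# Every mapping flavor flattens to the same leaf order.
assert flatten(d0) == flatten(d1) == flatten(d2) == ['a', 'z']

restored = unflatten_as(d0, [0, 1])
print(restored)  # {'z': 1, 'a': 0}
```

Sorted leaves keep flattening deterministic for equal dicts; the original order lives only in the structure used to unflatten.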

mattjj commented 4 years ago

@Conchylicultor good point! That sounds plausible.

uniq10 commented 4 years ago

@mattjj Would it be a good idea to register dict as a node, similar to what is done for OrderedDict and defaultdict?

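# Sketch of a change inside jax/tree_util.py itself, where register_pytree_node
# and safe_zip are already available; dict is currently handled natively (kDict).
register_pytree_node(
  dict,
  lambda x: (tuple(x.values()),  {key: None for key in x}.keys()),
  lambda keys, values: dict(safe_zip(keys, values))
)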

Of course, everything related to kDict would also need to be removed from jaxlib/pytree.{h,cc}.
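To see what this registration would do in isolation, its two callbacks can be run as plain Python functions. `flatten_dict` and `unflatten_dict` are hypothetical names, and `zip` with an explicit length check stands in for JAX's internal `safe_zip`. The keys view compares equal for equal dicts regardless of order, but the leaves still come out in insertion order, so equal dicts would flatten to differently ordered leaves, which is the concern raised earlier in the thread.

```python
def flatten_dict(x):
    # Leaves follow insertion order; the aux data is a keys view.
    return tuple(x.values()), {key: None for key in x}.keys()


def unflatten_dict(keys, values):
    keys = list(keys)
    if len(keys) != len(values):
        raise ValueError("number of keys and values must match")
    return dict(zip(keys, values))


x = {'z': 1, 'a': 2}
y = {'a': 2, 'z': 1}

leaves_x, aux_x = flatten_dict(x)
leaves_y, aux_y = flatten_dict(y)

print(aux_x == aux_y)      # True: keys views compare like sets
print(leaves_x, leaves_y)  # (1, 2) (2, 1): the leaf order differs
print(unflatten_dict(aux_x, leaves_x))  # {'z': 1, 'a': 2}
```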

DylanMuir commented 3 years ago

Adding my support for maintaining dict key ordering over flattening operations.

A related question: for a dictionary d, does tuple(d.values()) internally use tree_flatten? That operation also does not maintain key ordering when building the tuple.

jakevdp commented 2 years ago

Now that support for Python 3.6 has been dropped, we can probably revisit this.

wookayin commented 2 years ago

I believe this is an important bug to fix. Have we figured out why this is happening? (Or could someone point me to the relevant source code in the C++ implementation so I can dig into it?) I think the behavior should be consistent regardless of Python 3.6 support, but it'd be great to revisit this one.

jakevdp commented 2 years ago

@wookayin - it's happening because that's how tree flattening of dicts is implemented. The line where the sort is taking place is here: https://github.com/tensorflow/tensorflow/blob/eb8425f115e5a93274f709cdfaf254798f9aa4c7/tensorflow/compiler/xla/python/pytree.cc#L167

The problem is, "fixing" this is not as easy as just removing that sort. There are many parts of JAX that rely on equality comparisons of the flattened representation of dicts, and if you preserve insertion order in flattening, then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2), which has deep and subtle implications in the implementation of JAX transforms throughout the package.

For that reason, it's not clear to me whether this should be considered a bug, or just the way that flattening works in JAX (and it's why nobody as of yet has been eager to attempt making this change).

Conchylicultor commented 2 years ago

@jakevdp I don't understand your argument. This was already resolved in https://github.com/google/jax/issues/4085#issuecomment-675322841

Flattening would still be sorted, so if d1 == d2, then tree_flatten(d1) == tree_flatten(d2), irrespective of the key order of d1 and d2.

However, the key order would be restored during packing:

import tree

x = {'z': 'z', 'a': 'a'}

print(tree.flatten(x))  # Keys sorted: ['a', 'z']
print(tree.unflatten_as(x, [0, 1]))  # Key order restored: {'z': 1, 'a': 0}

So all dicts (dict, OrderedDict, ...) would have the same flattened representation, but would still preserve their key order when unflattened.

> then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2)

This is exactly the problem with the current Jax implementation:

import collections

import jax
import tree

d0 = {'z': 'z', 'a': 'a'}
d1 = collections.OrderedDict(d0)

assert d0 == d1  # Works
assert jax.tree_util.tree_leaves(d0) == jax.tree_util.tree_leaves(d1)  # << AssertionError: Oops ['a', 'z'] != ['z', 'a']

# By comparison, DM `tree` / tf.nest works as expected:
assert tree.flatten(d0) == tree.flatten(d1)  # Works: ['a', 'z'] == ['a', 'z']

jakevdp commented 2 years ago

That makes sense, thanks. I'd missed that comment from a few years ago.

jakevdp commented 2 years ago

It seems that there's broad agreement here that this should be fixed – we just need someone to take on the project.

XuehaiPan commented 2 years ago

Could we add a flag (e.g. a global variable) to let the user decide whether to sort the keys or not? For example:

jax.tree_util.dict_key_sorted(True)  # default behavior
jax.tree_util.dict_key_sorted(False)

In this issue, all the keys are strings, which are sortable. There is another issue about dict key sorting #11871. For a general PyTree, the keys are not always comparable:

tree = {1: '1', 'a': 'a'}

sorted(tree)  # <- TypeError: '<' not supported between instances of 'str' and 'int'

XuehaiPan commented 1 year ago

> Could we add a flag (e.g. a global variable) to let the user decide whether to sort the keys or not? For example: (breaks referential transparency)
>
> jax.tree_util.dict_key_sorted(True)  # default behavior
> jax.tree_util.dict_key_sorted(False)

> The problem is, "fixing" this is not as easy as just removing that sort. There are many parts of JAX that rely on equality comparisons of the flattened representation of dicts, and if you preserve insertion order in flattening, then d1 == d2 no longer implies that tree_flatten(d1) == tree_flatten(d2), which has deep and subtle implications in the implementation of JAX transforms throughout the package.
>
> For that reason, it's not clear to me whether this should be considered a bug, or just the way that flattening works in JAX (and it's why nobody as of yet has been eager to attempt making this change).

I agree, referential transparency should be a key feature of the pytree utilities: equal inputs imply equal outputs. However, the current implementation always sorts the keys and returns a new, sorted dict after unflattening. Nowadays, a lot of Python code relies on builtins.dict preserving insertion order. This may cause subtle bugs, because many people are not aware of this behavior of JAX pytrees (keys are sorted after tree_unflatten).

import jax

d = {'b': 2, 'a': 1}
# Map with the identity function changes the key order
out = jax.tree_util.tree_map(lambda x: x, d)  # => {'a': 1, 'b': 2}
d == out  # => True
list(d) == list(out)  # => False  ['b', 'a'] != ['a', 'b']

For example, when using tree_map to process kwargs (PEP 468 – Preserving the order of **kwargs in a function):

def func(*args, **kwargs):
    args, kwargs = jax.tree_util.tree_map(do_something, (args, kwargs))  # changes key order in kwargs
    ...

Another example, with NamedTuple annotations:

In [1]: import jax

In [2]: from typing import NamedTuple

In [3]: class Ints(NamedTuple):
   ...:     foo: int
   ...:     bar: int
   ...:     

In [4]: Ints(1, 2)
Out[4]: Ints(foo=1, bar=2)

In [5]: Ints(1, 2).foo
Out[5]: 1

In [6]: Ints.__annotations__
Out[6]: {'foo': <class 'int'>, 'bar': <class 'int'>}

In [7]: Floats = NamedTuple('Floats', **jax.tree_util.tree_map(lambda ann: float, Ints.__annotations__))

In [8]: Floats(1.0, 2.0)
Out[8]: Floats(bar=1.0, foo=2.0)

In [9]: Floats(1.0, 2.0).foo
Out[9]: 2.0
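Until the behavior changes, one user-side workaround for cases like the identity-map and NamedTuple examples above is to re-impose the input's key order on the mapped output. `restore_key_order` is a hypothetical helper, not part of jax.tree_util. So that the sketch runs without JAX, it hard-codes a key-sorted dict in place of an actual `tree_map` call; in practice you would pass the real output, e.g. `restore_key_order(d, jax.tree_util.tree_map(f, d))`.

```python
def restore_key_order(reference, result):
    """Reorder the dicts in `result` so their keys follow `reference`.

    Hypothetical helper: assumes `result` has the same structure as `reference`
    (e.g. it is the output of a tree_map over `reference`). Only plain dicts,
    lists and tuples are walked; any other container is returned unchanged.
    """
    if type(reference) is dict and type(result) is dict:
        return {k: restore_key_order(reference[k], result[k]) for k in reference}
    if type(reference) in (list, tuple) and type(result) is type(reference):
        return type(result)(restore_key_order(r, o) for r, o in zip(reference, result))
    return result


d = {'b': 2, 'a': {'z': 0, 'y': 1}}
# Stand-in for jax.tree_util.tree_map(lambda x: x, d), whose dicts come back key-sorted.
mapped = {'a': {'y': 1, 'z': 0}, 'b': 2}

fixed = restore_key_order(d, mapped)
print(fixed)  # {'b': 2, 'a': {'z': 0, 'y': 1}}
```

This only patches the output after the fact; it does not change how JAX flattens or traces the dict.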

One solution is to store the input dict keys in insertion order in Node during flatten, and update the PyTreeDef.unflatten method to respect the key order while reconstructing the output pytree.

import jax

leaves, treedef = jax.tree_util.tree_flatten({'b': 2, 'a': 1})
leaves   # [1, 2]
treedef  # PyTreeDef({'a': *, 'b': *})
treedef.unflatten([11, 22])  # proposed: {'b': 22, 'a': 11}, respecting the original key order (currently {'a': 11, 'b': 22})
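The proposal can be modeled with a toy treedef for a flat dict. `DictTreeDef` and `toy_flatten` are hypothetical and only illustrate the idea, not JAX's PyTreeDef. Equality and hashing use the sorted keys, so equal dicts still get equal treedefs and identical leaves, while the insertion order is stored only for rebuilding.

```python
class DictTreeDef:
    """Toy treedef for a flat dict (hypothetical; not JAX's PyTreeDef).

    Equality and hashing use the sorted keys, so equal dicts get equal
    treedefs; the insertion order is kept only for unflattening.
    """

    def __init__(self, d):
        self.sorted_keys = tuple(sorted(d))
        self.insertion_keys = tuple(d)

    def __eq__(self, other):
        return isinstance(other, DictTreeDef) and self.sorted_keys == other.sorted_keys

    def __hash__(self):
        return hash(self.sorted_keys)

    def unflatten(self, leaves):
        if len(leaves) != len(self.sorted_keys):
            raise ValueError("wrong number of leaves")
        by_key = dict(zip(self.sorted_keys, leaves))
        return {k: by_key[k] for k in self.insertion_keys}


def toy_flatten(d):
    treedef = DictTreeDef(d)
    return [d[k] for k in treedef.sorted_keys], treedef


leaves, treedef = toy_flatten({'b': 2, 'a': 1})
print(leaves)                       # [1, 2]
print(treedef.unflatten([11, 22]))  # {'b': 22, 'a': 11}
print(treedef == toy_flatten({'a': 1, 'b': 2})[1])  # True
```

One design consequence: two treedefs can compare equal yet restore different key orders, so any code that caches on the treedef would have to decide which order to use.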


Gattocrucco commented 1 year ago

Commenting to add that I just encountered this behavior and I find it quite annoying.

If I were to implement this, I'd use as the treedef a dictionary with the same keys but filled with None values. This way the implementation would piggyback entirely on Python and remain consistent under all circumstances.

carlosgmartin commented 4 days ago

Any update on this? I have a similar issue with jax.jit:

$ python3 -c "import jax; print(jax.jit(lambda: {'b': None, 'a': None})())"
{'a': None, 'b': None}