keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Fix handling of `OrderedDict` with optree and related documentation. #20481

Open hertschuh opened 2 weeks ago

hertschuh commented 2 weeks ago

The `tree` API had specific, but contradictory, documentation about the handling of `OrderedDict`s. The optree implementation of `flatten` did not honor this documentation (which says to use the sorted key order, not the insertion/sequence order), although the optree implementation of `pack_sequence_as` did. As a result, not only did `flatten` behave differently with optree and dm-tree, but `pack_sequence_as(structure, flatten(structure))` also did not return the original structure. The optree implementation already had all the machinery needed to handle `OrderedDict`s per the spec; it was used by `pack_sequence_as` but not by `flatten`. This PR also fixes the behavior discrepancy for namedtuples.
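To make the documented rule and the round-trip problem concrete, here is a minimal pure-Python sketch. The helpers `flatten_sorted` and `pack_sorted` are hypothetical stand-ins, not the Keras, optree or dm-tree implementations; they only traverse mappings, lists and tuples (no namedtuples), and they assume the documented rule that mapping values, including those of an `OrderedDict`, are visited in sorted key order while the rebuilt mapping keeps its original key order.

```python
from collections import OrderedDict
from collections.abc import Mapping


def flatten_sorted(structure):
    """Return the leaves of `structure`, visiting mapping values in sorted key order.

    Hypothetical helper: an OrderedDict is treated like a plain dict, so its
    insertion order is ignored. Only mappings, lists and tuples are traversed.
    """
    if isinstance(structure, Mapping):
        return [leaf for key in sorted(structure) for leaf in flatten_sorted(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten_sorted(item)]
    return [structure]


def pack_sorted(structure, flat):
    """Rebuild `structure` from `flat`, consuming leaves in the order flatten_sorted emits them."""
    leaves = iter(flat)

    def build(node):
        if isinstance(node, Mapping):
            # Consume leaves in sorted key order, exactly as flatten_sorted does...
            values = {key: build(node[key]) for key in sorted(node)}
            # ...then rebuild with the original key order and mapping type.
            return type(node)((key, values[key]) for key in node)
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        return next(leaves)

    return build(structure)


od = OrderedDict([("c", 3), ("a", 1), ("b", 2)])
print(flatten_sorted(od))                         # [1, 2, 3]
print(pack_sorted(od, flatten_sorted(od)) == od)  # True: the round trip holds

# Packing insertion-order leaves instead scrambles the values.
print(list(pack_sorted(od, [3, 1, 2]).items()))   # [('c', 2), ('a', 3), ('b', 1)]
```

The last line mirrors the broken round trip described above: when `flatten` emits leaves in insertion order but `pack_sequence_as` consumes them in sorted key order, values end up under the wrong keys.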

codecov-commenter commented 2 weeks ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 76.40%. Comparing base (b0b9d04) to head (def69ec).

:exclamation: There is a different number of reports uploaded between BASE (b0b9d04) and HEAD (def69ec).

HEAD has 2 fewer uploads than BASE.

| Flag | BASE (b0b9d04) | HEAD (def69ec) |
|------|------|------|
| keras | 4 | 3 |
| keras-torch | 1 | 0 |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master   #20481      +/-   ##
==========================================
- Coverage   82.07%   76.40%   -5.68%
==========================================
  Files         515      515
  Lines       47504    47512       +8
  Branches     7454     7457       +3
==========================================
- Hits        38991    36303    -2688
- Misses       6703     9452    +2749
+ Partials     1810     1757      -53
```

| Flag | Coverage Δ | |
|---|---|---|
| keras | `76.33% <100.00%> (-5.60%)` | :arrow_down: |
| keras-jax | `65.02% <100.00%> (+<0.01%)` | :arrow_up: |
| keras-numpy | `59.98% <100.00%> (+<0.01%)` | :arrow_up: |
| keras-tensorflow | `66.04% <100.00%> (+<0.01%)` | :arrow_up: |
| keras-torch | `?` | |

Flags with carried forward coverage won't be shown.

nicolaspi commented 1 week ago

Wouldn't it be better to wrap OrderedDict rather than re-implement flatten, which is written in C++ in optree?

e.g.:

```python
from collections import OrderedDict

import optree
from keras import tree


# Marker subclass so that only the OrderedDicts we wrap use the custom
# flatten/unflatten rules registered below.
class WrappedOrderedDict(OrderedDict):
  pass


def flatten_ordered_dict(d):
  # Children are emitted in sorted key order; the original (insertion) key
  # order is kept as hashable metadata so unflattening can restore it.
  keys = sorted(d.keys())
  values = [d[key] for key in keys]
  return values, tuple(d.keys()), keys


def unflatten_ordered_dict(metadata, children):
  # `children` arrive in sorted key order; map each key to its position and
  # rebuild the wrapped type in the original insertion order.
  index = {key: i for i, key in enumerate(sorted(metadata))}
  return WrappedOrderedDict((key, children[index[key]]) for key in metadata)


optree.register_pytree_node(
    WrappedOrderedDict,
    flatten_ordered_dict,
    unflatten_ordered_dict,
    namespace="keras",
)


def ordereddict_pytree_test():
  # Create an OrderedDict with deliberately unsorted keys.
  ordered_d = OrderedDict([("c", 3), ("a", 1), ("b", 2)])

  def wrap(s):
    # Returning None leaves the node as-is.
    if isinstance(s, OrderedDict):
      return WrappedOrderedDict(s)
    return None

  def unwrap(s):
    if isinstance(s, WrappedOrderedDict):
      return OrderedDict(s)
    return None

  d = tree.traverse(wrap, ordered_d, top_down=False)
  flat_d = tree.flatten(d)
  flat_d_paths = tree.flatten_with_path(d)
  assert flat_d == [1, 2, 3]
  assert [p[0] for p, v in flat_d_paths] == ["a", "b", "c"]

  # Round trip: pack the leaves back into the wrapped structure, then unwrap.
  wrapped_d = tree.pack_sequence_as(d, flat_d)
  orig_d = tree.traverse(unwrap, wrapped_d, top_down=False)
  assert isinstance(orig_d, OrderedDict)
  assert list(orig_d.keys()) == list(ordered_d.keys())
  assert list(orig_d.values()) == list(ordered_d.values())


ordereddict_pytree_test()
```

hertschuh commented 1 week ago

> Wouldn't it be better to wrap OrderedDict rather than re-implement flatten, which is written in C++ in optree?

Hi Nicolas,

Thank you for the suggestion. I have scrapped this PR entirely and taken a different approach: the optree behavior will become the reference behavior. The goal is indeed to maximize use of optree's C++ implementation, since optree is the default and dm-tree is only a fallback.
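As a rough illustration of what "optree as the reference" implies for ordering, here is a hedged pure-Python sketch. `flatten_reference` is a hypothetical helper, not an optree or Keras function, and it only models dicts, lists and tuples. It assumes what the PR description reports about optree's `flatten`: an `OrderedDict` is flattened in insertion order, while a plain `dict` is flattened in sorted key order.

```python
from collections import OrderedDict


def flatten_reference(structure):
    """Hypothetical sketch of optree-style ordering for dict-like nodes.

    OrderedDict values are visited in insertion order and plain dict values in
    sorted key order; lists and tuples are visited in sequence order.
    """
    # OrderedDict subclasses dict, so it must be checked first.
    if isinstance(structure, OrderedDict):
        keys = list(structure)
    elif isinstance(structure, dict):
        keys = sorted(structure)
    elif isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten_reference(item)]
    else:
        return [structure]
    return [leaf for key in keys for leaf in flatten_reference(structure[key])]


print(flatten_reference(OrderedDict([("c", 3), ("a", 1), ("b", 2)])))  # [3, 1, 2]
print(flatten_reference({"c": 3, "a": 1, "b": 2}))                     # [1, 2, 3]
```

Whatever ordering is chosen, the matching `pack_sequence_as` has to consume leaves in that same order for the round trip to hold.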