Open hertschuh opened 2 weeks ago
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 76.40%. Comparing base (b0b9d04) to head (def69ec).

:exclamation: There is a different number of reports uploaded between BASE (b0b9d04) and HEAD (def69ec).

HEAD has 2 uploads less than BASE

| Flag | BASE (b0b9d04) | HEAD (def69ec) |
|------|------|------|
|keras|4|3|
|keras-torch|1|0|
Wouldn't it be better to wrap `OrderedDict` rather than re-implement `flatten`, which is written in C++ in optree?
e.g.:
from collections import OrderedDict

import optree
from keras import tree


class WrappedOrderedDict(OrderedDict):
    pass


def flatten(d):
    values = []
    keys = []
    for key in sorted(d.keys()):
        values.append(d[key])
        keys.append(key)
    return values, list(d.keys()), keys


def unflatten(metadata, children):
    index = {key: i for i, key in enumerate(sorted(metadata))}
    return OrderedDict({key: children[index[key]] for key in metadata})


optree.register_pytree_node(
    WrappedOrderedDict,
    flatten,
    unflatten,
    namespace="keras",
)


def ordereddict_pytree_test():
    # Create an OrderedDict with deliberately unsorted keys
    ordered_d = OrderedDict([('c', 3), ('a', 1), ('b', 2)])

    def wrap(s):
        if isinstance(s, OrderedDict):
            return WrappedOrderedDict(s)
        return None

    def unwrap(s):
        if isinstance(s, WrappedOrderedDict):
            return OrderedDict(s)
        return None

    d = tree.traverse(wrap, ordered_d, top_down=False)
    flat_d = tree.flatten(d)
    flat_d_paths = tree.flatten_with_path(d)

    assert flat_d == [1, 2, 3]
    assert [p[0] for p, v in flat_d_paths] == ["a", "b", "c"]

    tree_struct = tree.traverse(wrap, ordered_d, top_down=False)
    wrapped_d = tree.pack_sequence_as(tree_struct, flat_d)
    orig_d = tree.traverse(unwrap, wrapped_d, top_down=False)

    assert isinstance(orig_d, OrderedDict)
    assert list(orig_d.keys()) == list(ordered_d.keys())
    assert list(orig_d.values()) == list(ordered_d.values())


ordereddict_pytree_test()
> Wouldn't it be better to wrap `OrderedDict` rather than re-implement `flatten`, which is written in C++ in optree?
Hi Nicolas,

Thank you for the suggestion. I actually completely scratched this PR and decided to use a different approach. The `optree` behavior will be the reference behavior. The goal is indeed to maximize the use of the C++ implementation of `optree` since it is the default and `dm-tree` is only a fallback.
The tree API had specific but contradicting documentation calling out the handling of `OrderedDict`s. However, the behavior of the `optree` implementation did not honor this documentation (using the key order, not the sequence order) for `flatten`, although it did for `pack_sequence_as`. The result was that not only did `flatten` not behave the same with `optree` and `dm-tree`, but also `pack_sequence_as(flatten(...))` was not idempotent. The `optree` implementation did have all the machinery needed to handle `OrderedDict`s per spec, which was used for `pack_sequence_as`, but not `flatten`. This also fixes the discrepancy in the behavior for `namedtuple`s.

- [...] `flatten` and `pack_sequence_as` related to the handling of `OrderedDict`s.
- [...] `unflatten_as`, which doesn't exist.
- [...] `if optree` tests in `tree_test.py`, which should not exist for consistency between `optree` and `dm-tree`.
- [...] `flatten_with_path`.
- [...] `tree` instead of `keras.tree` in unit test.
- [...] `dm-tree` uninstalled.
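To illustrate the mismatch described in this PR (key order used when flattening, sequence order used when packing), here is a minimal standard-library sketch. The helper names `flatten_by_sorted_keys`, `flatten_by_insertion_order`, and `pack_by_insertion_order` are hypothetical and only mimic the two orderings; they are not the Keras, `optree`, or `dm-tree` implementations.

```python
from collections import OrderedDict


def flatten_by_sorted_keys(d):
    """Hypothetical helper: return values in sorted-key order."""
    return [d[key] for key in sorted(d)]


def flatten_by_insertion_order(d):
    """Hypothetical helper: return values in insertion (sequence) order."""
    return list(d.values())


def pack_by_insertion_order(structure, flat_values):
    """Hypothetical helper: rebuild an OrderedDict, assigning the flat
    values to the keys of `structure` in insertion order."""
    return OrderedDict(zip(structure.keys(), flat_values))


# An OrderedDict whose insertion order differs from its sorted-key order.
original = OrderedDict([("c", 3), ("a", 1), ("b", 2)])

# Flatten by key order, pack by sequence order: values land on the wrong keys.
mismatched = pack_by_insertion_order(original, flatten_by_sorted_keys(original))

# Flatten and pack with the same ordering: the round trip restores the input.
consistent = pack_by_insertion_order(
    original, flatten_by_insertion_order(original)
)

print(mismatched == original)  # False
print(consistent == original)  # True
```

In this sketch the round trip only restores the input when both directions agree on one ordering, which is the consistency the description asks of `flatten` and `pack_sequence_as`.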