chiamp opened this issue 1 year ago
@cgarciae @marcvanzee
As discussed internally, we have to make `MutableDict` a proper pytree. Here is an idea for the implementation:
```python
from typing import Iterable, Mapping, MutableMapping, TypeVar, Union

import flax
import jax

A = TypeVar('A')
B = TypeVar('B')


class MutableDict(MutableMapping[A, B]):

  def __init__(
      self,
      input: Union[Mapping[A, B], Iterable[tuple[A, B]], None] = None,
      /,
      **kwargs: B,
  ):
    self._dict: dict[A, B] = dict(input, **kwargs) if input else dict(**kwargs)

  def __setitem__(self, key: A, value: B) -> None:
    self._dict[key] = value

  def __getitem__(self, key: A) -> B:
    value = self._dict[key]
    # Wrap nested plain dicts on access. MutableDict(value) copies `value`,
    # so writes through the returned wrapper do not reach `self`.
    if isinstance(value, dict) and not isinstance(value, MutableDict):
      return MutableDict(value)  # type: ignore
    return value

  def __delitem__(self, key: A) -> None:
    del self._dict[key]

  def __iter__(self):
    return iter(self._dict)

  def __len__(self):
    return len(self._dict)

  def __repr__(self):
    return 'MutableDict(' + flax.core.pretty_repr(self._dict) + ')'


jax.tree_util.register_pytree_with_keys(
    MutableDict,
    # flatten_with_keys: (DictKey, child) pairs plus the keys as aux data.
    lambda d: (
        tuple(
            (jax.tree_util.DictKey(key), value)
            for key, value in d._dict.items()
        ),
        tuple(d._dict.keys()),
    ),
    # unflatten_func: receives (aux_data, children) = (keys, values).
    lambda keys, values: MutableDict(zip(keys, values)),
    # flatten_func: must return the same children and aux data as
    # flatten_with_keys, otherwise unflatten_func cannot rebuild the mapping.
    lambda d: (tuple(d._dict.values()), tuple(d._dict.keys())),
)

d = MutableDict({'a': 1, 'b': {'c': 2, 'd': 3}})

print('\nprint\n--------------------------')
print(d)

print('\naccess\n--------------------------')
print(d['b'])

print('\ntree_flatten\n--------------------------')
print(jax.tree_util.tree_flatten(d)[0])

print('\ntree_flatten_with_path\n--------------------------')
print(jax.tree_util.tree_flatten_with_path(d)[0])
```
```
print
--------------------------
MutableDict({
    a: 1,
    b: {
        c: 2,
        d: 3,
    },
})

access
--------------------------
MutableDict({
    c: 2,
    d: 3,
})

tree_flatten
--------------------------
[1, 2, 3]

tree_flatten_with_path
--------------------------
[((DictKey(key='a'),), 1), ((DictKey(key='b'), DictKey(key='c')), 2), ((DictKey(key='b'), DictKey(key='d')), 3)]
```
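A minimal, self-contained way to check that a registration like the one above round-trips through `jax.tree_util.tree_map`, i.e. that mapping over the leaves gives back a `MutableDict`. It re-declares a stripped-down `MutableDict` (no Flax dependency, no wrapping of nested dicts) and assumes a JAX version that provides `register_pytree_with_keys` and `DictKey`. The third argument (`flatten_func`) returns the same children and auxiliary keys as `flatten_with_keys`, which is what `unflatten_func` expects when it rebuilds the mapping.

```python
from typing import Any, Iterator, MutableMapping

import jax


class MutableDict(MutableMapping[Any, Any]):
  """Stripped-down MutableDict used only to exercise the pytree registration."""

  def __init__(self, items=(), /, **kwargs):
    self._dict = dict(items, **kwargs)

  def __getitem__(self, key):
    return self._dict[key]

  def __setitem__(self, key, value):
    self._dict[key] = value

  def __delitem__(self, key):
    del self._dict[key]

  def __iter__(self) -> Iterator[Any]:
    return iter(self._dict)

  def __len__(self) -> int:
    return len(self._dict)


jax.tree_util.register_pytree_with_keys(
    MutableDict,
    # flatten_with_keys: ((DictKey, child) pairs, aux_data)
    lambda d: (
        tuple((jax.tree_util.DictKey(k), v) for k, v in d._dict.items()),
        tuple(d._dict.keys()),
    ),
    # unflatten_func: (aux_data, children) -> MutableDict
    lambda keys, values: MutableDict(zip(keys, values)),
    # flatten_func: same children and aux_data as flatten_with_keys
    lambda d: (tuple(d._dict.values()), tuple(d._dict.keys())),
)

d = MutableDict({'a': 1, 'b': {'c': 2, 'd': 3}})
# tree_map flattens to the leaves [1, 2, 3], doubles them, and unflattens.
doubled = jax.tree_util.tree_map(lambda x: x * 2, d)
print(type(doubled).__name__, dict(doubled))
# MutableDict {'a': 2, 'b': {'c': 4, 'd': 6}}
```

Because the nested `{'c': ..., 'd': ...}` is stored as a plain dict, JAX flattens it with its built-in dict handling, which is why its paths also show up as `DictKey` entries.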
I decided not to inherit from `dict`, since some of its built-in optimizations bypass overridden methods, and instead implemented the `MutableMapping` protocol.
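One concrete example of such an optimization in CPython: built-in `dict` methods like `get` read the underlying storage directly instead of calling an overridden `__getitem__`, so wrapping logic in a `dict` subclass is silently skipped, whereas `MutableMapping` derives `get` from `__getitem__`. `DictSubclass` and `MappingImpl` are hypothetical names used only for this illustration.

```python
from collections.abc import MutableMapping


class DictSubclass(dict):
  """Hypothetical dict subclass that tags every value on lookup."""

  def __getitem__(self, key):
    return ('wrapped', super().__getitem__(key))


class MappingImpl(MutableMapping):
  """Hypothetical MutableMapping with the same __getitem__ behavior."""

  def __init__(self, *args, **kwargs):
    self._dict = dict(*args, **kwargs)

  def __getitem__(self, key):
    return ('wrapped', self._dict[key])

  def __setitem__(self, key, value):
    self._dict[key] = value

  def __delitem__(self, key):
    del self._dict[key]

  def __iter__(self):
    return iter(self._dict)

  def __len__(self):
    return len(self._dict)


sub = DictSubclass(a=1)
impl = MappingImpl(a=1)
print(sub['a'], sub.get('a'))    # ('wrapped', 1) 1  <- dict.get skips __getitem__
print(impl['a'], impl.get('a'))  # ('wrapped', 1) ('wrapped', 1)
```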
We might need to do some deep checks, similar to what `FrozenDict` has, so that we don't end up with nested `MutableDict`s.
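A minimal sketch of one way to do such a deep check, similar in spirit to how `FrozenDict` normalizes its input: convert every nested mapping, including any `MutableDict`, to a plain `dict` on the way in, so the internal storage never holds a `MutableDict`. The `_unwrap` helper is hypothetical, and this version copies its input on construction and on assignment.

```python
from collections.abc import Mapping, MutableMapping
from typing import Any


def _unwrap(value: Any) -> Any:
  """Hypothetical helper: recursively turn any mapping into a plain dict."""
  if isinstance(value, Mapping):
    return {k: _unwrap(v) for k, v in value.items()}
  return value


class MutableDict(MutableMapping):

  def __init__(self, input=(), /, **kwargs):
    # Normalize on construction: nested MutableDicts become plain dicts.
    self._dict = _unwrap(dict(input, **kwargs))

  def __setitem__(self, key, value):
    # Normalize on assignment too, so nesting cannot sneak in later.
    self._dict[key] = _unwrap(value)

  def __getitem__(self, key):
    value = self._dict[key]
    # Nested plain dicts are re-wrapped (as a copy) only on access.
    return MutableDict(value) if isinstance(value, dict) else value

  def __delitem__(self, key):
    del self._dict[key]

  def __iter__(self):
    return iter(self._dict)

  def __len__(self):
    return len(self._dict)


inner = MutableDict(c=2)
outer = MutableDict(a=1, b=inner)
print(type(outer._dict['b']).__name__)  # dict
```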
I am also wondering if we should promote/expose `flax.core.pretty_print` as `nn.pretty_print` and use it in the guides so that users pick it up.
After the dict migration, Flax now returns regular dicts when calling the `.init`, `.init_with_output` and `.apply` Module methods. However, the representation of a regular dict is not as readable as the indented representation of a `FrozenDict`.

Regular dicts:
[screenshot: repr of a regular dict]

FrozenDicts:
[screenshot: indented repr of a FrozenDict]
The indented representation can be viewed by calling `flax.core.pretty_repr` on the dict. Alternatively, we could subclass `dict`, override its `__repr__` method to return an indented representation, and have Flax return this subclass when `.init`, `.init_with_output` and `.apply` are called.

Another option is to add a section to the dict migration guide letting users know that they can get the indented representation by calling `flax.core.pretty_repr` (although this currently works only on `FrozenDict`s and regular dicts, not on other objects like `TrainState`).
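The dict-subclass option could look roughly like the minimal sketch below. `IndentedDict` and `_pretty` are hypothetical names, and `_pretty` only approximates the layout of `flax.core.pretty_repr` so that the example runs without Flax; an actual implementation could simply call `flax.core.pretty_repr`. Note that, as far as I know, JAX looks up pytree node types by exact type, so such a subclass would also need its own registration (like `MutableDict` above) to be treated as a pytree rather than a leaf.

```python
def _pretty(x, indent: int = 4) -> str:
  """Hypothetical helper: indented repr for nested dicts; other values use repr()."""
  if not isinstance(x, dict):
    return repr(x)
  if not x:
    return '{}'
  pad = ' ' * indent
  lines = []
  for key, value in x.items():
    # Shift every line of the nested representation one level to the right.
    nested = _pretty(value, indent).replace('\n', '\n' + pad)
    lines.append(f'{pad}{key}: {nested},')
  return '{\n' + '\n'.join(lines) + '\n}'


class IndentedDict(dict):
  """Hypothetical dict subclass whose repr (and hence print output) is indented."""

  def __repr__(self) -> str:
    return _pretty(self)


params = IndentedDict({'dense': {'bias': 0.0, 'kernel': 1.0}})
print(params)
# {
#     dense: {
#         bias: 0.0,
#         kernel: 1.0,
#     },
# }
```

Since `IndentedDict` is still a `dict`, equality checks and existing code that expects a plain dict keep working; only the printed form changes.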