google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.07k stars 642 forks source link

Regular dict representation not indented #3280

Open chiamp opened 1 year ago

chiamp commented 1 year ago

After the dict migration, Flax now returns regular dicts when calling the `.init`, `.init_with_output` and `.apply` Module methods. However, the representation of regular dicts is not as readable as the indented representation of FrozenDicts.

Regular dicts:

class MLP(nn.Module):
  """Single 5-unit ``nn.Dense`` layer followed by a ReLU activation."""

  @nn.compact
  def __call__(self, x):
    # Under ``@nn.compact`` the submodule is declared inline; it shows up in
    # the params tree under the auto-generated name ``Dense_0``.
    return nn.relu(nn.Dense(5)(x))
# Example input batch for ``init``: one row with 3 features.
x = jnp.ones((1, 3))
model = MLP()
# ``init`` returns a dict of variable collections; only ``'params'`` is kept.
variables = model.init(jax.random.PRNGKey(0), x)
params = variables['params']
# Train state using Adam (learning rate 1e-3) over those params.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
)
state  # REPL echo of the state's repr
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params={'Dense_0': {'kernel': Array([[ 0.37229332, -0.4265755 , -1.1151816 , -0.09558704, -0.62169886],
       [-1.060781  ,  1.0546707 ,  0.33051118, -0.7090655 ,  0.37682843],
       [-0.30747807, -0.39064118, -0.25515485,  0.5127583 , -0.5559202 ]],      dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fe6daeca4d0>, update=<function chain.<locals>.update_fn at 0x7fe6daeca9e0>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu={'Dense_0': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32), 'kernel': Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)}}, nu={'Dense_0': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32), 'kernel': Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)}}), EmptyState()))

FrozenDicts:

# Same train state, but with the params frozen into a ``FrozenDict`` so its
# indented repr can be compared with the plain-dict one.
frozen_params = flax.core.freeze(params)
state = TrainState.create(
    apply_fn=model.apply,
    params=frozen_params,
    tx=optax.adam(1e-3),
)
state  # REPL echo of the state's repr
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params=FrozenDict({
    Dense_0: {
        kernel: Array([[ 0.37229332, -0.4265755 , -1.1151816 , -0.09558704, -0.62169886],
               [-1.060781  ,  1.0546707 ,  0.33051118, -0.7090655 ,  0.37682843],
               [-0.30747807, -0.39064118, -0.25515485,  0.5127583 , -0.5559202 ]],      dtype=float32),
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fe710e21ea0>, update=<function chain.<locals>.update_fn at 0x7fe710e225f0>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu=FrozenDict({
    Dense_0: {
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
        kernel: Array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.]], dtype=float32),
    },
}), nu=FrozenDict({
    Dense_0: {
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
        kernel: Array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.]], dtype=float32),
    },
})), EmptyState()))

The indented representation can be viewed by calling `flax.core.pretty_repr` on the dict. Alternatively, we could subclass `dict`, override its `__repr__` method to return an indented representation, and have Flax return this subclass when `.init`, `.init_with_output` and `.apply` are called:

class MutableDict(dict):
  """A plain ``dict`` subclass whose repr is indented like ``FrozenDict``'s.

  Deliberately *not* decorated with ``flax.struct.dataclass``. A dataclass
  with no fields flattens to zero pytree leaves, so ``tree_map``-built copies
  (e.g. optimizer moment estimates) come back empty. Its generated ``__eq__``
  compares the (empty) field tuples, which makes any two instances compare
  equal. Instead, the pytree registration below flattens the dict's entries.
  """

  def __repr__(self):
    # Same ``Name({...})`` shape as ``FrozenDict``'s repr.
    return 'MutableDict(' + flax.core.pretty_repr(self) + ')'


import jax  # no-op if already imported; needed for the registration below

# Children are the top-level values in insertion order. The key tuple is the
# aux data, so unflattening rebuilds a MutableDict with the same keys/order.
jax.tree_util.register_pytree_with_keys(
    MutableDict,
    lambda d: (
        tuple((jax.tree_util.DictKey(k), v) for k, v in d.items()),
        tuple(d.keys()),
    ),
    lambda keys, values: MutableDict(zip(keys, values)),
)

# Same train state, with the params wrapped in the MutableDict class from the
# snippet above so its indented repr can be compared.
wrapped_params = MutableDict(params)
state = TrainState.create(
    apply_fn=model.apply,
    params=wrapped_params,
    tx=optax.adam(1e-3),
)
state  # REPL echo of the state's repr
TrainState(step=0, apply_fn=<bound method Module.apply of MLP()>, params=MutableDict{
    Dense_0: {
        kernel: Array([[ 0.37229332, -0.4265755 , -1.1151816 , -0.09558704, -0.62169886],
               [-1.060781  ,  1.0546707 ,  0.33051118, -0.7090655 ,  0.37682843],
               [-0.30747807, -0.39064118, -0.25515485,  0.5127583 , -0.5559202 ]],      dtype=float32),
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fe710e3a5f0>, update=<function chain.<locals>.update_fn at 0x7fe710e39990>), opt_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu=MutableDict{}, nu=MutableDict{}), EmptyState()))

Another option is to add a section to the dict migration guide letting users know they can get the indented representation by calling `flax.core.pretty_repr` (although this currently works only on FrozenDicts and regular dicts, not on other objects like `TrainState`).

chiamp commented 1 year ago

@cgarciae @marcvanzee

cgarciae commented 1 year ago

As discussed internally, we have to make MutableDict a proper pytree. Here is an idea for the implementation:

from typing import Iterable, Mapping, MutableMapping, TypeVar, Union

import flax
import jax

A = TypeVar('A')
B = TypeVar('B')


class MutableDict(MutableMapping[A, B]):
  """A mutable mapping whose repr is indented like ``FrozenDict``'s.

  Entries live in a private plain ``dict`` (``self._dict``). Nested mappings
  are stored as plain ``dict``s and only *viewed* as ``MutableDict`` when
  accessed. This has two consequences:

  * the pytree structure does not depend on how an instance was built;
  * writes made through a nested view (``d['b']['c'] = 5``) reach the parent.

  Implements the ``MutableMapping`` protocol instead of subclassing ``dict``,
  so mixin methods such as ``get``/``items``/``values`` go through
  ``__getitem__``.
  """

  def __init__(
      self,
      input: Union[Mapping[A, B], Iterable[tuple[A, B]], None] = None,
      /,
      **kwargs: B,
  ):
    """Builds a shallow copy of ``input`` (if given), updated with ``kwargs``.

    Accepts the same positional forms as ``dict()``: a mapping or an iterable
    of ``(key, value)`` pairs. Every entry goes through ``__setitem__``, so
    ``MutableDict`` values are unwrapped to their underlying plain dict.
    """
    self._dict: dict[A, B] = {}
    # ``is not None`` rather than truthiness: empty inputs are already
    # handled by ``update``, and truth-testing arbitrary inputs can raise.
    self.update(() if input is None else input, **kwargs)

  @classmethod
  def _wrap(cls, d: dict) -> 'MutableDict':
    """Returns a ``MutableDict`` view that shares (does not copy) ``d``."""
    view = cls()
    view._dict = d
    return view

  def __setitem__(self, key: A, value: B) -> None:
    # Store a MutableDict value as its underlying plain dict, so the
    # container never mixes the two node types in its pytree structure.
    if isinstance(value, MutableDict):
      value = value._dict  # type: ignore
    self._dict[key] = value

  def __getitem__(self, key: A) -> B:
    """Returns the value for ``key``; nested dicts come back as live views.

    The view shares storage with this container. Wrapping through the
    constructor would copy the nested dict and silently drop writes.

    Raises:
      KeyError: if ``key`` is not present.
    """
    value = self._dict[key]
    if isinstance(value, dict):
      return MutableDict._wrap(value)  # type: ignore
    return value

  def __delitem__(self, key: A) -> None:
    del self._dict[key]

  def __iter__(self):
    return iter(self._dict)

  def __len__(self):
    return len(self._dict)

  def __repr__(self):
    return 'MutableDict(' + flax.core.pretty_repr(self._dict) + ')'


# Children are the top-level values in insertion order, and the key tuple is
# the aux data. ``flatten_func`` must return exactly the same children (same
# order) and the same aux data as ``flatten_with_keys``; otherwise
# ``tree_unflatten``/``tree_map`` pass mismatched aux data to the unflattener.
jax.tree_util.register_pytree_with_keys(
    MutableDict,
    lambda d: (
        tuple(
            (jax.tree_util.DictKey(key), value)
            for key, value in d._dict.items()
        ),
        tuple(d._dict.keys()),
    ),
    lambda keys, values: MutableDict(zip(keys, values)),
    lambda d: (tuple(d._dict.values()), tuple(d._dict.keys())),
)

d = MutableDict({'a': 1, 'b': {'c': 2, 'd': 3}})

# (header, thunk) pairs. Each thunk runs only after its header is printed,
# so the output order matches printing every section inline.
demos = (
    ('\nprint\n--------------------------', lambda: d),
    ('\naccess\n--------------------------', lambda: d['b']),
    ('\ntree_flatten\n--------------------------',
     lambda: jax.tree_util.tree_flatten(d)[0]),
    ('\ntree_flatten_with_path\n--------------------------',
     lambda: jax.tree_util.tree_flatten_with_path(d)[0]),
)
for header, compute in demos:
  print(header)
  print(compute())
print
--------------------------
MutableDict({
    a: 1,
    b: {
        c: 2,
        d: 3,
    },
})

access
--------------------------
MutableDict({
    c: 2,
    d: 3,
})

tree_flatten
--------------------------
[1, 2, 3]

tree_flatten_with_path
--------------------------
[((DictKey(key='a'),), 1), ((DictKey(key='b'), DictKey(key='c')), 2), ((DictKey(key='b'), DictKey(key='d')), 3)]

I decided not to inherit from `dict`, because CPython applies internal optimizations to `dict` subclasses (which can bypass overridden methods), and instead implemented the `MutableMapping` protocol.

cgarciae commented 1 year ago

We might need to do some deep checks so that we don't end up with nested MutableDicts, similar to what FrozenDict does. I am wondering whether we should instead just promote/expose `flax.core.pretty_repr` as `nn.pretty_repr` and use it in the guides so that users pick it up?