google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Should Flax return FrozenDicts or regular dicts? #1223

Closed: marcvanzee closed this issue 1 year ago.

marcvanzee commented 3 years ago

This topic is discussed regularly internally, and I feel we haven't reached a consensus here. Below are some arguments collected from users for both positions; feel free to add more.

Arguments in favor of FrozenDict

def f(params):
  params['conv1']['weight'] = ...
  return ...some computation over params

params = load_from_checkpoint()
print(f(params))
# now what is the value of params['conv1']['weight']?
# depending on whether f is jitted or not, you'd get different results
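A runnable sketch of the scenario above, assuming a plain nested dict of JAX arrays stands in for the checkpoint (`fresh_params` is a hypothetical helper, not part of Flax). Calling `f` eagerly writes into the caller's dict, while `jax.jit(f)` traces `f` on a dict rebuilt from the flattened pytree, so the caller's dict is left untouched.

```python
import jax
import jax.numpy as jnp


def f(params):
    # In-place write into the nested dict that f receives.
    params['conv1']['weight'] = params['conv1']['weight'] * 2
    return jnp.sum(params['conv1']['weight'])


def fresh_params():
    # Hypothetical stand-in for load_from_checkpoint().
    return {'conv1': {'weight': jnp.ones(3)}}


eager_params = fresh_params()
eager_out = f(eager_params)  # mutates eager_params in place

jit_params = fresh_params()
jit_out = jax.jit(f)(jit_params)  # f mutates a traced copy; jit_params is unchanged

print(eager_out, jit_out)               # same return value in both cases
print(eager_params['conv1']['weight'])  # [2. 2. 2.]
print(jit_params['conv1']['weight'])    # [1. 1. 1.]
```

With a FrozenDict, the assignment inside `f` would raise an error in both cases instead of silently diverging.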

Arguments in favor of regular dicts

n2cholas commented 3 years ago

I think the Python saying "We're all consenting adults here" is pretty fitting. In my view, trading safety for convenience is reasonable here because JAX users should know (or will quickly come to learn) that they should not mutate state under the JAX transformations. Since FrozenDicts are not as ergonomic as normal dicts, I tend to unfreeze them as soon as they're returned from init anyway.

Though I would prefer the user-facing API to just use dict, I wouldn't mind FrozenDict if the behaviour was closer to dict for non-mutating cases. In particular,

Explicit state management is one of my favourite aspects of Flax, as it gives me the ability to transparently manipulate modules/parameters without worrying about hidden side effects. I totally agree with @lucasb-eyer's point that it's counterproductive to provide explicit state without allowing the user to fully control it.

jheek commented 3 years ago

I think the Python saying "We're all consenting adults here" is pretty fitting

Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link). It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.

That said, I don't think FrozenDict has proven to be a very effective safety tool for avoiding this kind of error. We should probably keep using it internally to avoid accidental reference sharing, but for users it seems too big a burden, and it doesn't prevent the more common issues of closing over mutable state (typically created by the user) or using things like np.random in a jitted function.
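To make those two failure modes concrete, here is a small sketch (the function names are illustrative): `jax.jit` evaluates `np.random` and reads closed-over Python state only while tracing, so later calls with the same input shape and dtype reuse the cached trace. Freezing the parameter dicts does nothing to prevent either case.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def add_noise(x):
    # np.random runs once while tracing; the sample is baked in as a constant.
    return x + np.random.normal()


scale = [1.0]  # mutable state created by the user and closed over below


@jax.jit
def scaled(x):
    # scale[0] is read once at trace time, not on every call.
    return x * scale[0]


a = add_noise(jnp.zeros(()))
b = add_noise(jnp.zeros(()))  # cache hit: same "random" noise as a

before = scaled(jnp.ones(()))
scale[0] = 2.0
after = scaled(jnp.ones(()))  # still uses 1.0: no retrace for the same input shape/dtype
```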

I do think we should at least provide an easy way to clone a pytree if we allow it to contain mutable containers. Something like the following:

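import jax

def clone_pytree(xs):
  # cloning is just an identity mapping: tree_map rebuilds every container
  return jax.tree_util.tree_map(lambda x: x, xs)

def some_nested_transformation(variables):
  # proposed to live at flax.traverse_util.clone_pytree (not an existing API)
  my_copy = clone_pytree(variables)
  my_copy['batch_stats']['x'] += 2.
  return my_copy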
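A quick check of what the identity `tree_map` actually copies, using a plain nested dict as a hypothetical `variables`: the dict containers are rebuilt, but the array leaves are the very same objects. That is sufficient for JAX arrays, which are immutable; mutable leaves such as NumPy arrays would still be shared between the original and the clone.

```python
import jax
import jax.numpy as jnp


def clone_pytree(xs):
    # Identity tree_map: rebuilds every container, reuses every leaf.
    return jax.tree_util.tree_map(lambda x: x, xs)


variables = {'params': {'w': jnp.ones(2)}, 'batch_stats': {'x': jnp.zeros(2)}}
my_copy = clone_pytree(variables)

print(my_copy is variables)                                # False: new outer dict
print(my_copy['batch_stats'] is variables['batch_stats'])  # False: new inner dict
print(my_copy['params']['w'] is variables['params']['w'])  # True: leaf is shared

my_copy['batch_stats']['x'] += 2.     # rebinds the key in the copy only
print(variables['batch_stats']['x'])  # [0. 0.]
```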

Also we want to merge the chex and flax dataclass implementation. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent so ideally we would make these changes together.

n2cholas commented 3 years ago

Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link). It's a hard issue to fix though because mutability "infects" all of your code and Python isn't a functional language.

This is a good point; especially as a codebase grows, it can sneak past you. Personally, if FrozenDict better matched the ergonomics of dict outside of mutation, I would not see it as a burden, at least for my own use cases.

Also we want to merge the chex and flax dataclass implementation. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent so ideally we would make these changes together.

I actually quite like the immutable dataclasses, since the .replace(...) API is similar to that of namedtuples. In my view, the inconveniences that arise with FrozenDict don't happen here, since dataclasses don't have arbitrary structure and you don't generally manipulate that structure.
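For comparison, the two update styles in standard-library form (the `TrainState` types here are hypothetical examples): a frozen dataclass is updated with `dataclasses.replace`, a namedtuple with `_replace`, and both return a new object while leaving the original untouched.

```python
import dataclasses
from collections import namedtuple


@dataclasses.dataclass(frozen=True)
class TrainState:  # hypothetical example type
    step: int
    lr: float


TrainStateTuple = namedtuple('TrainStateTuple', ['step', 'lr'])

s = TrainState(step=0, lr=0.1)
s2 = dataclasses.replace(s, step=1)  # new frozen instance

t = TrainStateTuple(step=0, lr=0.1)
t2 = t._replace(step=1)              # namedtuple equivalent

try:
    s.step = 5                       # frozen dataclasses reject attribute assignment
    rejected = False
except dataclasses.FrozenInstanceError:
    rejected = True
```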

marcvanzee commented 3 years ago

Thanks for the input @n2cholas! After chatting with @jheek offline, the consensus is that it is indeed useful to return regular dicts, but that we block implementing this on merging the chex and flax dataclasses.

lucasb-eyer commented 3 years ago

Great, very happy about this decision. I'd just like to add that

See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link).

Is a complete red herring. This is about hidden global state, whereas this discussion is specifically about explicit, non-global state. It's actually more about rng design than anything else, and what we are talking about doing here is already the "better" rng design where the user explicitly is given, and trusted to correctly handle, the state.
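The explicit-state RNG design described here is the one `jax.random` already uses: the key is an ordinary value that the caller owns, passes, and splits, so reusing a key is visible in the code rather than hidden in a global. A minimal sketch:

```python
import jax


key = jax.random.PRNGKey(0)          # explicit RNG state, owned by the caller
key, subkey = jax.random.split(key)  # derive a fresh key for one use

x1 = jax.random.normal(subkey, (3,))
x2 = jax.random.normal(subkey, (3,))  # same key -> same numbers, deterministically

key, subkey2 = jax.random.split(key)
x3 = jax.random.normal(subkey2, (3,))  # new key -> new numbers
```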

PhilipVinc commented 3 years ago

Is there any further development on this?

marcvanzee commented 3 years ago

Sorry for the delay -- I was on parental leave.

@jheek could you tell us whether any progress has been made on merging the chex and flax dataclasses?

NeilGirdhar commented 2 years ago

What does merging the dataclasses consist of? Are flax dataclasses going to be inheriting the mapping interface?

jheek commented 2 years ago

The merging of dataclasses is taking much longer than originally anticipated. I'll bring this up in our next sync meeting, because I think we should start to move towards allowing mutability independently of actually merging the implementations with chex.

NeilGirdhar commented 2 years ago

I think we should start to move towards allowing mutability

Sorry, but why would you do that?

The merging of dataclasses is taking much longer than originally anticipated.

Also, I still don't understand what this merge will consist of. Flax's dataclasses are well-designed: they are just frozen dataclasses that register as pytrees, have a field function that conveniently supports marking static fields, and add a replace method. Besides the replace method (which is just a shortcut to dataclasses.replace), this is a minimal interface.
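The interface described here (a frozen dataclass registered as a pytree, a field helper for marking static fields, and `replace`) is small enough to sketch with the standard library plus `jax.tree_util`. This is an illustrative reimplementation, not Flax's code; `pytree_dataclass`, `static_field`, and the `'static'` metadata key are hypothetical names.

```python
import dataclasses
import jax
import jax.numpy as jnp


def static_field(**kwargs):
    # Mark a field as static: it goes into the treedef, not the leaves.
    return dataclasses.field(metadata={'static': True}, **kwargs)


def pytree_dataclass(cls):
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    dynamic = [f.name for f in fields if not f.metadata.get('static', False)]
    static = [f.name for f in fields if f.metadata.get('static', False)]

    def flatten(obj):
        children = tuple(getattr(obj, n) for n in dynamic)
        aux = tuple(getattr(obj, n) for n in static)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        return cls(**dict(zip(dynamic, children)), **dict(zip(static, aux)))

    def replace(self, **updates):
        # replace is just a shortcut for dataclasses.replace
        return dataclasses.replace(self, **updates)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    cls.replace = replace
    return cls


@pytree_dataclass
class Linear:
    w: jnp.ndarray
    use_bias: bool = static_field(default=False)


layer = Linear(w=jnp.ones(2))
doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer)  # static field untouched
layer2 = layer.replace(use_bias=True)
```

Because `use_bias` lives in the treedef, `jax.jit` treats it as a compile-time constant, while `w` is traced.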

Chex dataclasses are badly designed: they are not frozen, they can't mark static fields, and they unnecessarily expose the whole mapping interface, which means you can access fields as attributes or keys. They also expose a to_tuple method that is inferior to dataclasses.astuple, which supports nested dataclasses. The from_tuple method is also somewhat flimsy, since it won't work with Python 3.10's new keyword-only arguments. This is not a minimal interface.

I was hoping to ditch tjax's dataclasses in favor of flax's, but if you're merging in any of chex's behavior, I won't be able to.

jheek commented 2 years ago

We won't be removing features like frozen, static fields, and replace. We do however want to be less strict about enforcing functional patterns. Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.

As for the mapping interface: this is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types. This is also why chex cannot easily add static fields: tf.nest doesn't support them. We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.

NeilGirdhar commented 2 years ago

Many users find it difficult to deal with frozen dataclasses/dicts. At the end of the day Python is not a functional language and partially making it behave like one can be awkward.

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.

As a data point: in my 5,500-line Jax project, I call replace 9 times. It may be slightly more awkward than writing to attributes, but I don't think it's worth giving up the safety of having all of the methods on my dataclasses verified to be pure.

This is actually what's blocking a merge. Chex dataclases support tf.nest and dm-tree. Which is an alternative to jax.tree_util that relies on the mapping interface and doesn't support custom types.

Instead of having a gigantic interface and passing the dataclass d to tf.nest, can't users pass dataclasses.asdict(d)?

This is also why chex cannot easily add static fields because tf.nest doesn't support it.

I see. Why not create an asdict function that removes the keys corresponding to static fields? Or, more conveniently, convince TensorFlow to check for an as_dynamic_dict method and call it in tf.nest?
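A sketch of the first suggestion, assuming static fields are marked through field metadata (the `'static'` metadata key, `as_dynamic_dict`, and `Block` are hypothetical names): a helper that returns only the dynamic fields as a plain dict, which a mapping-based library such as tf.nest could traverse without the dataclass itself implementing the mapping interface.

```python
import dataclasses


def as_dynamic_dict(obj):
    # Like dataclasses.asdict, but drops fields marked static via metadata.
    # Shallow: nested dataclasses are returned as-is.
    return {
        f.name: getattr(obj, f.name)
        for f in dataclasses.fields(obj)
        if not f.metadata.get('static', False)
    }


@dataclasses.dataclass(frozen=True)
class Block:
    weights: list
    num_layers: int = dataclasses.field(default=2, metadata={'static': True})


print(as_dynamic_dict(Block(weights=[1.0, 2.0])))  # {'weights': [1.0, 2.0]}
```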

We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.

Yes! Thank you!

jheek commented 2 years ago

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a good idea.

Yes, this is the tradeoff we have to think about, and we will discuss it further before making a final decision.

avital commented 2 years ago

I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that Jax's decorated functions (jit, grad, etc.) are functional. That may feel awkward, but I think Flax's idea to enforce that was a brilliant idea.

What is a particular form of this problem? Typically the code in the main training loop isn't pure anyway (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc.). I understand the need to ensure frozen data structures within modules (and we're not proposing to change this: module.apply will still have a mutable argument and use FrozenDicts based on that). The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.

PhilipVinc commented 2 years ago

By the way, I also think that Flax returning frozen dictionaries is extremely annoying. Changing this behaviour would also address https://github.com/deepmind/optax/issues/160

Moreover, our (NetKet) users and students learning Jax/Flax often find it confusing that they keep getting this object that they have to "melt" before they can edit it.

NeilGirdhar commented 2 years ago

The only proposed change that I am aware of is changing the signatures of module.init and module.apply to not return FrozenDicts.

Sorry, I'm not actually discussing the topic of the issue. I just noticed a comment about merging chex.dataclass, and I wanted some clarification on that.

What is a particular form of this problem?

I can't find the example, but I saw one with treex (which doesn't enforce frozen dataclasses) where someone was doing

def f(x):
    x.some_member = some_value
    return x

@jit
def g(x):
    ...
    x = f(x)  # if you forget to assign to x, you will get different behavior for the jitted and unjitted function. 

Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc).

Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.

avital commented 2 years ago

Typically the code in the main training loop isn't pure anyways (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc).

Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.

@NeilGirdhar I just mean things like updating state and params and reporting metrics -- it's totally fine to directly manipulate the variables dict in the main training loop, and people have to jump through (IMHO unnecessary) hoops to achieve this: https://github.com/google/flax/issues/1729#issuecomment-995839207.

NeilGirdhar commented 2 years ago

@avital Fair enough. I need to learn Flax better before I can really suggest something. A couple of other options:

An at operator that does this under the covers, so that you can write:

embedding = params['params']['Embed_0']['embedding']
norm = jnp.linalg.norm(embedding, axis=-1, keepdims=True)
new_params = params.at['params']['Embed_0']['embedding'].divide(norm + 1e-10)
state = state.replace(params=new_params)

The at operator would return a handle like the one returned by `.at` on JAX arrays (e.g. `x.at[idx].divide(y)`).
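Without any new `at` API, the same functional update can be sketched today with a small recursive helper that copies only the dicts along the path. This assumes plain nested dicts stand in for the params; `update_at` is a hypothetical name, not a Flax function.

```python
import jax.numpy as jnp


def update_at(tree, path, fn):
    # Return a copy of `tree` with tree[path[0]]...[path[-1]] replaced by fn(old).
    # Only the dicts along `path` are copied; all other subtrees are shared.
    if not path:
        return fn(tree)
    key, rest = path[0], path[1:]
    new = dict(tree)
    new[key] = update_at(tree[key], rest, fn)
    return new


def normalize(embedding):
    norm = jnp.linalg.norm(embedding, axis=-1, keepdims=True)
    return embedding / (norm + 1e-10)


params = {'params': {'Embed_0': {'embedding': jnp.array([[3.0, 4.0], [0.0, 2.0]])}}}
new_params = update_at(params, ('params', 'Embed_0', 'embedding'), normalize)
```

The original `params` is left untouched, and each update copies only as many dicts as the path is deep.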

Or maybe a context manager that provides the handle and automatically rolls it back in when it ends:

with state.unfreeze() as unfrozen_state:
    unfrozen_state.params['params']['Embed_0']['embedding'] /= (jnp.linalg.norm(state.params['params']['Embed_0']['embedding'], axis=-1, keepdims=True) + 1e-10)

You'd still be jumping through hoops, but it's just one hoop.

avital commented 2 years ago

The problem with any hoop isn't its complexity; it's that it's something you have to learn suddenly, when you "just wanted to try this one thing". So any hoop should be justified by the benefit it gives you (hopefully a lot). Maybe I'm just misunderstanding this, but I never understood the benefit of having module.apply and module.init return FrozenDicts. (I've always strongly supported FrozenDicts inside modules, which happens internally as a function of the mutable argument to module.apply.)

avital commented 2 years ago

I guess another way to put it -- if someone really wants immutable data structures, they can always do, e.g. FrozenDict(module.init(...)). So the question is: which default serves the users best?

lucasb-eyer commented 2 years ago

And the answer is just plain dict, at least for this user here :)

cgarciae commented 2 years ago

+1 for this! I have a lot of code that immediately calls .unfreeze() right after init and apply.

cgarciae commented 2 years ago

Hey @NeilGirdhar! I believe you're looking for this example from Treex's User Guide.

cgarciae commented 1 year ago

Since this would be a breaking change, we should bump Flax's version to avoid breaking open-source users who rely on semantic versioning.

marcvanzee commented 1 year ago

FYI: @chiamp is going to look into this

chiamp commented 1 year ago

Closing after #3193 landed.