google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Issues restoring checkpoint of struct.dataclass w/ FrozenDict attr #676

Open rwightman opened 4 years ago

rwightman commented 4 years ago

Working from a modified ImageNet Linen example, I've added two state attributes holding the EMA (Polyak-averaged) values, like so:

from typing import Any

import flax  # flax.optim is only available in older Flax releases

@flax.struct.dataclass
class TrainState:
    step: int
    optimizer: flax.optim.Optimizer
    model_state: Any
    dynamic_scale: flax.optim.DynamicScale
    ema_params: flax.core.FrozenDict = None  # lazy init on first step
    ema_model_state: flax.core.FrozenDict = None  # lazy init on first step

Restoring a checkpoint with that state causes an error because the FrozenDicts are restored as plain dicts. I'm not sure whether this is a bug or a feature request (i.e., is this expected?). I noticed there are registered functions for restoring state dicts and FrozenDict is among the registered types; shouldn't that cover this case? Or should I wrap my EMA state in another class and register my own state dict restore function that freezes the dicts?

I'm currently using this hack after restoring to work around the issue:

    # After restoring from a checkpoint, the EMA fields come back as plain dicts.
    if step_offset > 0:
        state = state.replace(
            ema_params=flax.core.freeze(state.ema_params),
            ema_model_state=flax.core.freeze(state.ema_model_state))

avital commented 4 years ago

I think most likely this is simply an oversight -- dataclasses should restore FrozenDicts. We should probably have a unit test serializing and deserializing a FrozenDict to make sure it comes out as a FrozenDict as well.

Roping in @jheek as he's been doing some changes to FrozenDict lately.

avital commented 4 years ago

@jheek -- if this fix doesn't require deep understanding, maybe it's best to mark it as "pull requests encouraged", as it looks like @rwightman has a workaround for now.

jheek commented 4 years ago

This is actually not a trivial problem. We restore attributes based on the type of the corresponding value in the target, but when that value is None we cannot deduce the type. I think this would just work if you use init_train_state = TrainState(..., ema_params=FrozenDict(), ema_model_state=FrozenDict())

rwightman commented 4 years ago

@jheek ah, OK... so in this case, yeah, the empty dict is falsy, which is a sufficient 'not initialized' signal that the rest of my lazy-init logic should still work.

rwightman commented 4 years ago

@jheek I thought this would be a quick and easy fix, but I ended up going down a rabbit hole. The idea doesn't work.

Because of the way FrozenDict restoration works, you cannot restore a FrozenDict when the state and the target have different keys (here, the target has no keys at all). For the typical use case I guess that makes sense, but here it's a non-obvious, silent failure: you just end up with an empty dict even though the checkpoint contained a populated one... confusing.
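Both failure modes in this thread (a None target in the original report, an empty FrozenDict target here) follow from restoration being driven by the target rather than by the checkpoint. The toy model below is a simplified sketch, not Flax's implementation; `Frozen` and `restore` are hypothetical stand-ins that mimic the behaviour described above.

```python
class Frozen(dict):
    """Toy stand-in for FrozenDict: just a dict with a distinct type."""


def restore(target, state):
    # Containers are rebuilt by walking the *target's* keys, not the state's.
    if isinstance(target, Frozen):
        return Frozen({k: restore(v, state[k]) for k, v in target.items()})
    if isinstance(target, dict):
        return {k: restore(v, state[k]) for k, v in target.items()}
    # Anything else (including None) is a leaf: the serialized value is
    # returned unchanged, so a nested dict comes back as a plain dict.
    return state


ckpt = {'ema_params': {'kernel': 1.0}}

# Failure 1: target field is None -> a plain dict comes back, not a Frozen.
r1 = restore({'ema_params': None}, ckpt)
print(type(r1['ema_params']).__name__)  # dict

# Failure 2: target field is an empty Frozen -> nothing is restored, silently.
r2 = restore({'ema_params': Frozen()}, ckpt)
print(r2['ema_params'])  # {}
```

Only a target that already has the right type and the right keys gets the full value back, which is why neither a None placeholder nor an empty FrozenDict works as a lazy-init default.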

EDIT: So I guess I'm back to the cleanest path forward being to implement my own EmaUpdater() class that holds the two optional FrozenDicts and evaluates as False when it hasn't been initialized with params + state. I'd write my own to/from state dict methods for that class, register them, and avoid calling the normal FrozenDict methods...

jheek commented 4 years ago

Yes, I think that is cleaner. Another alternative, if you don't want to register a bunch of classes, is to use restore_checkpoint(target=None), which gives you the raw state dict. You can then "pre-process" that state dict however you like.

rwightman commented 4 years ago

@jheek I created an EmaState dataclass and got that working in a less hacky way. There's still one issue: allowing training to start with EMA active and then disabling it (or the other way around) seems to require custom serialization because of the way None values are handled.

So, a question about handling None in either the serialized state or the target: is the current behaviour ever correct or useful? On deserialization, if a target value is None, the type can't be determined, so the serialized state dict is just dumped into the target as-is. If the state value is None, restoration crashes trying to iterate over the state.

Wouldn't None be more useful as a 'do not deserialize' sentinel? If a target value is None, restoration would skip deserializing that field. If the serialized state for that field is None, the current target value would be left unchanged. That seems a much more useful pattern: it would give some natural forward/backward checkpoint compatibility and allow toggling active state fields between sessions.
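The proposed semantics can be pinned down with a small pure-Python sketch. `restore_field`, `restore_fields` and `typed` are hypothetical and do not describe how Flax behaves; they only encode the two rules above (None target: skip the field; None in the checkpoint: keep the target's value).

```python
from typing import Any, Callable

Restore = Callable[[Any, Any], Any]


def restore_field(target: Any, state: Any, restore: Restore) -> Any:
    """Proposed None-as-sentinel rules (hypothetical; not Flax's behaviour)."""
    if target is None:
        return None  # field disabled this session: skip the checkpoint value
    if state is None:
        return target  # field missing/disabled in the checkpoint: keep target
    return restore(target, state)  # otherwise, the usual type-driven restore


def restore_fields(target: dict, state: dict, restore: Restore) -> dict:
    # A field missing from the checkpoint is treated like an explicit None.
    return {name: restore_field(value, state.get(name), restore)
            for name, value in target.items()}


def typed(target: Any, state: Any) -> Any:
    return type(target)(state)  # toy stand-in for type-driven restoration


ckpt = {'step': 10, 'ema_params': {'w': 1.0}}
# EMA turned off this session: the checkpointed EMA params are skipped.
print(restore_fields({'step': 0, 'ema_params': None}, ckpt, typed))
# {'step': 10, 'ema_params': None}
# EMA turned on, old checkpoint without EMA: the fresh init is kept.
print(restore_fields({'step': 0, 'ema_params': {'w': 0.0}}, {'step': 10}, typed))
# {'step': 10, 'ema_params': {'w': 0.0}}
```

Under these rules both toggling directions restore cleanly without any per-class serialization code.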

If there is a good reason not to use None for the described functionality, would an explicit sentinel type make sense?

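from flax import struct

@struct.dataclass
class TrainState:
    # _UNUSED() is a proposed (not existing) sentinel: don't save or restore this field
    blahblah: SomeOtherState = _UNUSED()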