Open rwightman opened 4 years ago
I think most likely this is simply an oversight -- dataclasses should restore FrozenDicts. We should probably have a unit test serializing and deserializing a FrozenDict to make sure it comes out as a FrozenDict as well.
Roping in @jheek as he's been doing some changes to FrozenDict lately.
@jheek -- if this fix doesn't require deep understanding, maybe best to mark it as "pull requests encouraged", as it looks like @rwightman has a workaround for now.
This is actually not a trivial problem. We restore attributes based on the type of the original (target) value, but when that value is None we cannot deduce the type. I think this would just work if you use
init_train_state = TrainState(..., ema_params=FrozenDict(), ema_model_state=FrozenDict())
@jheek ah, k... so in this case yeah, the empty dict acts as a sufficient 'not-initialized' truthy value that the rest of my lazy logic should still work.
@jheek I thought this would be a quick and easy fix but ended up going down a rabbit hole. The idea doesn't work.
You cannot restore a FrozenDict when state and target have different keys (in this case, no keys in the target), because of the way the FrozenDict restore works. For the typical use case I guess that makes sense, but it's a non-obvious, silent failure in this case (you just end up restoring an empty dict even though there was a populated one in the checkpoint)... confusing.
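To make the two failure modes concrete, here is a minimal pure-Python sketch. It is a toy model of type-driven restoring, not Flax's actual code: `restore`, `restore_frozen`, and the use of `types.MappingProxyType` as a stand-in for FrozenDict are assumptions made for illustration. It assumes the restore walks the target's keys, which is what the behaviour described above suggests.

```python
from types import MappingProxyType

def restore_frozen(target, state):
    """Toy restore: walk the TARGET's keys and look each one up in `state`."""
    return MappingProxyType({key: state[key] for key in target})

def restore(target, state):
    """Dispatch on the target's type; unknown types pass the state through."""
    if isinstance(target, MappingProxyType):
        return restore_frozen(target, state)
    return state  # includes target=None: there is no type to dispatch on

checkpoint = {"w": 1.0, "b": 0.0}  # raw state as stored: a plain dict

# Target is None: the raw dict comes back unchanged, so it is not frozen.
none_target = restore(None, checkpoint)

# Target is an empty frozen mapping: no keys to walk, so nothing is restored.
empty_target = restore(MappingProxyType({}), checkpoint)

# Target has the same keys as the checkpoint: everything is restored.
full_target = restore(MappingProxyType({"w": 0.0, "b": 0.0}), checkpoint)
```

In this model neither a None target nor an empty target raises an error, which is why the problem only shows up later, when the restored value is used.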
EDIT: So I guess I'm back at the cleanest path forward being to implement my own EmaUpdater() class holding the two optional FrozenDicts, one that evaluates as False when not initialized with params + state. I write my own to/from state dict methods for that class, register them, and avoid calling the normal FrozenDict methods...
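A rough sketch of what such a class could look like, written in plain Python so it runs without Flax. The field names, the `ema_to_state_dict` / `ema_from_state_dict` helpers, and the choice to store empty dicts instead of None are all assumptions for illustration, not the code actually used in this thread.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Optional

@dataclass
class EmaUpdater:
    """Holds optional EMA copies of params and model state (hypothetical)."""
    params: Optional[Mapping[str, Any]] = None
    model_state: Optional[Mapping[str, Any]] = None

    def __bool__(self) -> bool:
        # Falsy until EMA params have been set, so `if ema:` can gate lazy init.
        return self.params is not None

def ema_to_state_dict(ema: EmaUpdater) -> dict:
    # Store empty dicts instead of None so the checkpoint never contains None.
    return {
        "params": dict(ema.params or {}),
        "model_state": dict(ema.model_state or {}),
    }

def ema_from_state_dict(target: EmaUpdater, state: Mapping[str, Any]) -> EmaUpdater:
    # Read keys from the saved state rather than from `target`, so an
    # uninitialized target still restores everything that was saved.
    # In real Flax code the dicts would presumably be frozen again here.
    return EmaUpdater(
        params=state.get("params") or None,
        model_state=state.get("model_state") or None,
    )

# Round trip: an initialized instance restores into an empty target.
saved = ema_to_state_dict(EmaUpdater(params={"w": 1.0}, model_state={"mean": 0.0}))
restored = ema_from_state_dict(EmaUpdater(), saved)
```

With Flax, the two functions would be hooked up through `flax.serialization.register_serialization_state(EmaUpdater, ema_to_state_dict, ema_from_state_dict)`; that step is left out here to keep the sketch dependency-free.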
Yes, I think that is cleaner. Another alternative, if you don't want to register a bunch of classes, is to use restore_checkpoint(target=None), which gives you the raw state dict. You can then "pre-process" this state dict however you like.
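As an illustration of that kind of pre-processing, the sketch below works on a plain nested dict standing in for the raw state dict. The helper `split_ema` is hypothetical; the key names `ema_params` and `ema_model_state` are taken from the TrainState fields mentioned earlier in the thread, and the sample values are made up.

```python
from typing import Any, Dict, Tuple

EMA_KEYS = ("ema_params", "ema_model_state")

def split_ema(raw: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Separate the EMA entries from the rest of a raw checkpoint dict."""
    # Missing or None entries become empty dicts, so callers get a uniform shape.
    ema = {key: raw.get(key) or {} for key in EMA_KEYS}
    rest = {key: value for key, value in raw.items() if key not in EMA_KEYS}
    return rest, ema

# Stand-in for the dict returned by restore_checkpoint(..., target=None).
raw_state = {"step": 10, "ema_params": {"w": 0.5}, "ema_model_state": None}
rest, ema = split_ema(raw_state)
```

After a step like this, the non-EMA part can be restored against a regular target and the EMA part handled separately (for example, frozen by hand).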
@jheek I created an EmaState dataclass and got that working in a less hacky fashion. Still have a bit of an issue: allowing training to start with ema active and then disabling it, or the other way around, seems to require custom serialization due to the way None values are handled.
So a question about handling None in either the serialized state or the target. Is the current behaviour ever correct or useful? On deserialization, if a target value is None, the type isn't determined properly and it just dumps the dict of serialized state into the target. If the state value is None, it crashes trying to iterate over the state.
Wouldn't None be more useful as a 'do not deserialize' sentinel? If a target value is None, it would not try to deserialize and restore that field. If the serialized state for that field was None, it would leave the current target unchanged. That seems much more useful as a pattern, since it would allow some natural checkpoint fwd/bwd compat or toggling of active state fields between sessions.
If there is a good reason not to use None for the described functionality, would an explicit type make sense?
@struct.dataclass
class TrainState:
  # _UNUSED: proposed explicit "not in use" marker type (does not exist yet)
  blahblah: SomeOtherState = _UNUSED()
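The proposed 'do not deserialize' rule can be sketched per field as follows. This describes the suggestion made in this comment, not how Flax behaves; `restore_field` and the toy `restore_dict` restorer are hypothetical names, and an explicit marker such as `_UNUSED` could replace the `is None` checks.

```python
from typing import Any, Callable

def restore_field(target: Any, state: Any,
                  restore: Callable[[Any, Any], Any]) -> Any:
    """Restore one field under the proposed 'None means skip' rule."""
    if target is None:
        # The caller is not using this field: do not deserialize it.
        return None
    if state is None:
        # The checkpoint has nothing for this field: keep the current value.
        return target
    return restore(target, state)

def restore_dict(target: dict, state: dict) -> dict:
    # Toy restorer: take the saved value for each key the target expects.
    return {key: state[key] for key in target}

# EMA disabled in this session, but present in the checkpoint: field stays None.
skipped = restore_field(None, {"w": 1.0}, restore_dict)

# EMA enabled in this session, but absent from the checkpoint: keep the init value.
kept = restore_field({"w": 0.0}, None, restore_dict)

# Both present: normal restore.
loaded = restore_field({"w": 0.0}, {"w": 1.0}, restore_dict)
```

The first two cases are the ones that matter for toggling EMA on or off between sessions without custom serialization.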
Working from a modified ImageNet Linen example, I've added two state attrs for the Polyak-averaged EMA values, like so
Restoring the checkpoints with that state causes an error, as the FrozenDicts get restored as dicts. I'm not sure if this is a bug or a feature request (i.e. is this expected?). I noticed there is a registration fn for restoring state dicts, and FrozenDicts are among the registered types; should that not cover this case? Or should I wrap my ema state in another class and register my own state dict restore fn that freezes the dicts?
I'm currently doing this hack after restore to work around the issue...