Closed: @OmaymaMahjoub closed this 10 months ago
This is brilliant, @OmaymaMahjoub! The proposed changes (primarily where we reload the params in the code) make the checkpointing much cleaner—thank you.
A few ideas:
We can now remove the following block: https://github.com/instadeepai/Mava/blob/15b2106c1a48a689bba2ba19e94ed060f3d1fbe0/mava/utils/checkpointing.py#L142-L145
We now reconstruct the learner state manually, which means we don't have to worry about saving the type of our state (FF or RNN). I like this, as the alternative always felt a bit hacky :)
- We probably shouldn't call our method `restore_learner_state` anymore, since it returns params only (i.e. not opt_state etc.). I think this is good, but let's rename to match? I feel that `restore_params` is fine (even though `hstates` are also restored sometimes), but open to other ideas! (A rough sketch of what this could look like is below.)
- Along with the above, I'm wondering if we should even be checkpointing the entire learner state (`unreplicated_learner_state`), instead of just the params/hstates?
  - We shouldn't: because we're only loading the params anyway, so why save unnecessary stuff like opt_state?
  - We should: because in the future we might want to load the opt_state again, or just have more info at hand. Thoughts?
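For concreteness, here's a minimal sketch of what the renamed method and the manual reconstruction could look like. The names here (`restore_params`, `rebuild_learner_state`, `LearnerState`, and the layout of the restored pytree) are illustrative assumptions, not Mava's actual API:

```python
from typing import NamedTuple, Optional, Tuple

import chex
import optax


class LearnerState(NamedTuple):
    """Illustrative learner state; Mava's real state carries more fields."""

    params: chex.ArrayTree
    opt_state: optax.OptState
    hstates: Optional[chex.ArrayTree] = None


def restore_params(restored: dict) -> Tuple[chex.ArrayTree, Optional[chex.ArrayTree]]:
    """Hypothetical rename of `restore_learner_state`: only the params (and the
    hstates, when they were saved) are pulled out of the restored checkpoint."""
    return restored["params"], restored.get("hstates")


def rebuild_learner_state(
    restored: dict, optimiser: optax.GradientTransformation
) -> LearnerState:
    """Reconstruct the learner state manually: restored params plus a fresh
    opt_state, so the state type (FF vs RNN) never has to be checkpointed."""
    params, hstates = restore_params(restored)
    opt_state = optimiser.init(params)
    return LearnerState(params=params, opt_state=opt_state, hstates=hstates)
```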
Otherwise, great stuff! Thanks again @OmaymaMahjoub 🚀
Thanks @callumtilbury for the review! I addressed the first two points. For the third point, yes, I kept the storing of the `learner_state`, as I thought it might be helpful in the future; keen to hear others' thoughts on that, though!
What?
Fix the checkpointer to accept `dictconfig` and load non-`FrozenDict` params.
Why?
Currently, trying to add the metadata config raises an error because the config is not a dict, and loading checkpointed params also fails because the params type in the systems is not `FrozenDict`.
How?
`update_batch_size` during storing
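As a rough illustration of the two conversions this description implies (the helper names below are hypothetical and the PR's actual implementation may differ):

```python
from typing import Union

from flax.core.frozen_dict import FrozenDict
from omegaconf import DictConfig, OmegaConf


def to_checkpointable_metadata(config: DictConfig) -> dict:
    """Convert the OmegaConf config into a plain dict before attaching it
    to a checkpoint as metadata."""
    return OmegaConf.to_container(config, resolve=True)


def to_system_params(restored_params: Union[FrozenDict, dict]) -> dict:
    """Return restored params as the mutable dict type the systems use,
    rather than a FrozenDict."""
    if isinstance(restored_params, FrozenDict):
        return restored_params.unfreeze()
    return restored_params
```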