instadeepai / Mava

šŸ¦ A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0
737 stars 90 forks

feat: Model Checkpointing #954

Closed callumtilbury closed 11 months ago

callumtilbury commented 11 months ago

What?

Ability to checkpoint the learner_state using orbax as a base.

Why?

Important for our research in multiple dimensions.

How?

Built the same way as in our internal CityLearn code, by wrapping orbax's functionality.

Extra

This PR adds simple checkpointing to each system, but we can disable it by default, if preferred. It does not include any checkpoint reloading, but that can be done in the following way (example for Rec MAPPO with RWARE):

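import jax
from flax.core.frozen_dict import FrozenDict

# Checkpointer, config, Params, HiddenStates and learner_state are assumed
# to already be in scope in the system file being run.

# Restore the learner state saved by the run with this timestamp.
loaded_state = Checkpointer(
    model_name="rec_mappo_rware",
    config=config,
    timestamp_override="20231127155319",
).restore_learner_state()

# Replicate the restored state across all available devices.
loaded_state = jax.device_put_replicated(loaded_state, jax.devices())

# Rebuild the Params and HiddenStates named tuples from the restored dicts.
learner_state = learner_state._replace(
    params=Params._make(FrozenDict(loaded_state['params']).values()),
    hstates=HiddenStates._make(FrozenDict(loaded_state['hstates']).values()),
)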

But I do want to think of cleaner ways to reload, perhaps for a later PR?
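One possible direction for a cleaner reload, sketched here under assumptions rather than taken from this PR: the snippet above rebuilds each named tuple positionally from `.values()`, which only works if the restored dict iterates in the same order as the tuple's fields. Matching fields by name avoids that dependency. The `Params` class and its field names below are hypothetical stand-ins, and `namedtuple_from_mapping` is an illustrative helper, not part of Mava or orbax.

```python
from typing import Any, Mapping, NamedTuple


class Params(NamedTuple):
    # Hypothetical stand-in for the Params container used in the system file.
    actor_params: Any
    critic_params: Any


def namedtuple_from_mapping(cls, restored: Mapping[str, Any]):
    """Build `cls` from a mapping, matching fields by name, not by position."""
    missing = [name for name in cls._fields if name not in restored]
    if missing:
        raise KeyError(f"restored state is missing fields: {missing}")
    return cls(**{name: restored[name] for name in cls._fields})


# A restored dict whose key order differs from the field order of Params.
restored = {"critic_params": {"w": 2.0}, "actor_params": {"w": 1.0}}

# Name-based reconstruction puts each entry in the right field.
params = namedtuple_from_mapping(Params, restored)

# Positional reconstruction silently swaps the two entries here.
swapped = Params._make(restored.values())
```

Because the helper accepts any `Mapping`, it should work equally on a plain dict or a `FrozenDict`, and it fails loudly if a field is absent from the checkpoint instead of shifting values into the wrong slot.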

callumtilbury commented 11 months ago

Open Q: should I add checkpointing options to config, or only later, when we find a need?

OmaymaMahjoub commented 11 months ago

> Open Q: should I add checkpointing options to config, or only later, when we find a need?

Thanks for the PR @callumtilbury. In my opinion, it would be better to make checkpointing optional and also to make the model_name a config param.
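As a rough illustration of this suggestion, the sketch below gates checkpointer creation behind a config flag and reads the model name from config. Everything here is an assumption for illustration: the `CheckpointConfig` fields (`save_model`, `model_name`), the `maybe_make_checkpointer` helper, and the factory argument are hypothetical and do not reflect the config schema this PR ends up with.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CheckpointConfig:
    # Hypothetical config fields; checkpointing is off unless enabled.
    save_model: bool = False
    model_name: str = "rec_mappo_rware"


def maybe_make_checkpointer(
    cfg: CheckpointConfig, factory: Callable[[str], Any]
) -> Optional[Any]:
    """Return a checkpointer built by `factory`, or None when disabled."""
    if not cfg.save_model:
        return None
    return factory(cfg.model_name)


# Stand-in factory that just records the name it was given.
def fake_factory(model_name: str) -> dict:
    return {"model_name": model_name}


disabled = maybe_make_checkpointer(CheckpointConfig(), fake_factory)
enabled = maybe_make_checkpointer(
    CheckpointConfig(save_model=True, model_name="my_run"), fake_factory
)
```

With this shape, the training loop only needs a single `if checkpointer is not None:` check around its save call, so systems that do not want checkpointing pay no cost.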