ShaneFlandermeyer / tdmpc2-jax

Jax/Flax Implementation of TD-MPC2

Saving trained model #3

Closed: Ozzey closed this issue 4 months ago

Ozzey commented 4 months ago

Are there any plans to add support for saving trained single-task and multi-task models, and for continuing to train those models later?

ShaneFlandermeyer commented 4 months ago

I plan to add a save/load feature in the not-so-distant future. If you would like to do it right now, you should be able to add Orbax checkpoints into the main training loop without much trouble.
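Below is a minimal sketch of the checkpoint-and-resume pattern described above. It uses plain `pickle` in place of Orbax so it runs with only NumPy installed. In an Orbax setup, the library's checkpoint manager would take over the save/restore role, and the loop structure would stay the same: restore the newest step if one exists, then save every N updates. `save_checkpoint`, `restore_latest`, `train`, and the `step_<N>.pkl` file naming are hypothetical and are not part of tdmpc2-jax. The update step is a placeholder.

```python
import os
import pickle
import re
import tempfile

import numpy as np

# Hypothetical helpers: a pickle-based stand-in for Orbax, not the tdmpc2-jax API.
_CKPT_RE = re.compile(r"^step_(\d+)\.pkl$")


def save_checkpoint(ckpt_dir, step, state):
    """Write `state` (any picklable pytree of arrays) for `step`, atomically."""
    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"step_{step}.pkl")
    # Write to a temp file first so a crash never leaves a half-written checkpoint.
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, final_path)
    return final_path


def restore_latest(ckpt_dir):
    """Return (step, state) for the newest checkpoint, or None if there is none."""
    if not os.path.isdir(ckpt_dir):
        return None
    steps = [int(m.group(1)) for name in os.listdir(ckpt_dir)
             if (m := _CKPT_RE.match(name))]
    if not steps:
        return None
    step = max(steps)
    with open(os.path.join(ckpt_dir, f"step_{step}.pkl"), "rb") as f:
        return step, pickle.load(f)


def train(ckpt_dir, total_steps, save_every=100):
    """Toy training loop that resumes from the latest checkpoint if present."""
    restored = restore_latest(ckpt_dir)
    if restored is None:
        start_step = 0
        state = {"params": {"w": np.zeros(3)}, "opt_count": 0}
    else:
        start_step, state = restored
    for step in range(start_step, total_steps):
        # Placeholder for a real TD-MPC2 update (world model, value, policy).
        state["params"]["w"] = state["params"]["w"] + 1.0
        state["opt_count"] += 1
        # The saved step is the number of updates completed so far.
        if (step + 1) % save_every == 0:
            save_checkpoint(ckpt_dir, step + 1, state)
    return state
```

Calling `train(d, 250)` saves at steps 100 and 200. A later `train(d, 300)` then picks up from step 200 and does not start over. Any updates made after the last save (here, steps 201 to 250) are lost, so `save_every` trades disk I/O against how much work a crash can discard.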

I only consider single-task scenarios in my research, so multi-task models are low on my list of priorities. I would gladly welcome any pull requests if that's a highly desired feature though :)

Ozzey commented 4 months ago

Great! I am working on making your JAX implementation of TD-MPC2 more object-oriented, with a structure similar to the Stable-Baselines3 (SB3) algorithms. I'll try to add multi-task support and checkpoints. Could you test it once it's ready?
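SB3 algorithms follow a `model.learn(...)`, `model.save(path)`, `Algo.load(path)` pattern, in which `load` is a class-level constructor. A minimal sketch of that interface with continued training might look like the code below. `HypotheticalAgent`, its fields, and the pickle file format are illustrative only. This is not Ozzey's refactor or the tdmpc2-jax API, and `learn` is a placeholder that only counts timesteps.

```python
import pickle

import numpy as np


class HypotheticalAgent:
    """SB3-style interface sketch: construct, learn(), save(), load()."""

    def __init__(self, obs_dim, action_dim, seed=0):
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        rng = np.random.default_rng(seed)
        # Placeholder parameters; a real agent would hold Flax train states.
        self.params = {"policy": rng.normal(size=(obs_dim, action_dim))}
        self.num_timesteps = 0

    def learn(self, total_timesteps):
        # Placeholder for environment interaction and TD-MPC2 updates.
        self.num_timesteps += total_timesteps
        return self

    def save(self, path):
        # Store constructor arguments alongside the learned state.
        data = {
            "obs_dim": self.obs_dim,
            "action_dim": self.action_dim,
            "params": self.params,
            "num_timesteps": self.num_timesteps,
        }
        with open(path, "wb") as f:
            pickle.dump(data, f)

    @classmethod
    def load(cls, path):
        # Rebuild the agent, then overwrite its fresh state with the saved one.
        with open(path, "rb") as f:
            data = pickle.load(f)
        agent = cls(data["obs_dim"], data["action_dim"])
        agent.params = data["params"]
        agent.num_timesteps = data["num_timesteps"]
        return agent
```

Because `load` restores `num_timesteps` along with the parameters, a reloaded agent's next `learn` call continues the count rather than starting from zero. This is the "continued training" behavior the original question asks about.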

ShaneFlandermeyer commented 4 months ago

Thanks to some recent contributions in #4, we now have checkpointing and TensorBoard logging in the develop branch! Closing this issue now. Feel free to open another if you would like to continue discussing multi-task learning.