nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

About continue training #7

Closed VitaLemonTea1 closed 7 months ago

VitaLemonTea1 commented 7 months ago

Hi,

I have another question. After training the model for 1,000,000 steps, I stopped the training to do some other work. Today I planned to continue training from the checkpoint at 1,000,000 steps:

python train.py task=mt80 model_size=1 batch_size=1024 checkpoint="/Project/TD——MPC/tdmpc2/logs/mt80/1/default/models/1000000.pt"

After it had trained for half a day, I realized that it did not continue from the checkpoint, but instead started a new training run from scratch. So I would like to know how I can continue training from the 1,000,000-step checkpoint.

Thanks!

nicklashansen commented 7 months ago

train.py currently does not use the checkpoint argument; only evaluate.py does. This seems like a very reasonable request though, so I will push a commit with this functionality soon.

Copying this line https://github.com/nicklashansen/tdmpc2/blob/f3139291e2dc8e47480184a4a1bce05e8980caa3/tdmpc2/evaluate.py#L59 from evaluate.py to train.py (potentially with a few extra checks / warnings) should do the trick.
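For reference, a minimal sketch of what that could look like in train.py, assuming the agent exposes a `load()` method (the counterpart of the `agent.save()` mentioned below) and that `cfg.checkpoint` is the config field set on the command line:

```python
import os

# Hypothetical resume logic for train.py; the names cfg and agent are
# assumed to be in scope, mirroring how evaluate.py loads a checkpoint.
if cfg.checkpoint:  # e.g. checkpoint=".../models/1000000.pt" on the command line
    assert os.path.exists(cfg.checkpoint), f"Checkpoint not found: {cfg.checkpoint}"
    print(f"Resuming from checkpoint: {cfg.checkpoint}")
    agent.load(cfg.checkpoint)  # same call evaluate.py uses before rollout
```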

One thing to note though is that agent.save() currently only stores the model weights, not the optimizer state, which would be needed for seamless resuming of a training run. I can add that to the checkpoint files, but it will roughly double the file size.
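For what it's worth, a hedged sketch of how checkpointing could be extended to include optimizer state; the attribute names (`agent.model`, `agent.optim`) are assumptions and not necessarily the actual TD-MPC2 API:

```python
import torch

def save_checkpoint(agent, fp):
    """Hypothetical helper: store model weights together with optimizer state."""
    torch.save({
        "model": agent.model.state_dict(),
        "optim": agent.optim.state_dict(),  # this extra state is what roughly doubles the file size
    }, fp)

def load_checkpoint(agent, fp):
    """Hypothetical helper: restore weights and, if present, optimizer state."""
    ckpt = torch.load(fp, map_location="cpu")
    agent.model.load_state_dict(ckpt["model"])
    if "optim" in ckpt:  # older weight-only checkpoints remain loadable
        agent.optim.load_state_dict(ckpt["optim"])
```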