jheagerty opened this issue 1 year ago
And sorry, for context: the checkpointing is so that I can implement self-play, where the baseline model that controls the enemy is periodically updated to match the model we're training.
Never mind, I figured out the main thing for me, checkpointing. You have to:
My one concern is that I cannot tell whether the learning rate scheduler state is maintained, but I will worry about that later.
As for collecting metrics less frequently, I will also worry about that later, but I've seen some likely jittable tools.
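For example, one jittable-looking pattern is to gate the logging on the host side inside a debug callback, and return None from the scanned update so per-step metrics are never stacked on the GPU. A minimal toy sketch (the interval and the stand-in metric are just placeholders):

import jax

LOG_EVERY = 10  # placeholder logging interval

def _log(metric, step):
    # Runs on the host, so a plain Python `if` controls the frequency;
    # the traced function never sees this branch.
    if step % LOG_EVERY == 0:
        print(f"update {int(step)}: metric {float(metric):.3f}")

def _update_step(carry, unused):
    step, rng = carry
    rng, key = jax.random.split(rng)
    metric = jax.random.normal(key, ())  # stand-in for real training metrics
    jax.debug.callback(_log, metric, step)
    # Returning None instead of `metric` keeps lax.scan from stacking a
    # large per-update metric pytree on the device.
    return (step + 1, rng), None

@jax.jit
def train(rng):
    return jax.lax.scan(_update_step, (0, rng), None, length=100)

train(jax.random.PRNGKey(0))

The Python-level `if` runs outside the trace, so no lax.cond is needed; dropping the scanned metric output is what actually saves the memory.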
@jheagerty Actually, I think you can save checkpoints under jit easily with callbacks such as jax.experimental.io_callback() (for example, inside _update_step to save after each update).
Yep! I do it the way @Howuhh is describing. If you look at the code, there is a debug callback; you can just replace the print function with your checkpointing and wandb logging.
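In sketch form, that pattern might look like the following (the path, interval, and toy update are placeholders, and orbax is assumed for the serialization):

import jax
import jax.numpy as jnp
from jax.experimental import io_callback
import orbax.checkpoint as ocp

SAVE_EVERY = 100  # placeholder checkpoint interval
checkpointer = ocp.PyTreeCheckpointer()

def _save(train_state, step):
    # Host side: receives concrete arrays, so orbax can serialize them.
    if step % SAVE_EVERY == 0:
        checkpointer.save(f"/tmp/ckpts/{int(step)}", train_state, force=True)

def _update_step(carry, unused):
    step, params = carry
    params = jax.tree_util.tree_map(lambda p: p + 1.0, params)  # stand-in update
    # ordered=True keeps the saves sequenced with the computation.
    io_callback(_save, None, params, step, ordered=True)
    return (step + 1, params), None

@jax.jit
def train(params):
    return jax.lax.scan(_update_step, (0, params), None, length=300)

train({"w": jnp.zeros((4, 4))})

The callback receives concrete arrays on the host, so any checkpointing or wandb logic can consume them directly.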
Thanks so much! Will look into this.
The way you do it should be fine too, and is arguably better (though it takes more code). The optimizer parameters (which include the lr scheduling info) should be in the train_state.
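To illustrate the lr-schedule point (a rough sketch; the toy model is only there to make it runnable): with optax the schedule is a pure function of a step count that lives inside opt_state, so checkpointing the whole train_state carries the schedule position along.

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training.train_state import TrainState

class Tiny(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

model = Tiny()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
schedule = optax.linear_schedule(init_value=3e-4, end_value=0.0, transition_steps=1000)
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(learning_rate=schedule),
)
# The schedule is a pure function of a step count, and that count lives in
# state.opt_state, so saving/restoring the whole train_state also restores
# the schedule's position.
print(state.opt_state)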
@luchris429 Hi Chris! I was wondering, how did you implement restoring of checkpoints in the PureJaxRL end-to-end jitting? I'm able to save checkpoints pretty easily with a debug callback function, but I can't quite figure out how to restore. I attempted to put an experimental.io_callback function in the train function, but I can't actually do anything with the string checkpoint path because JAX can't handle strings.
You can try to load the runner state here!
Does it not work if you set the filename in the config?
So I tried something like the code below at exactly the line you pointed out (in a modified PPO script where I split the actor/critic):
def resuming_callback(path):
    checkpointer = ocp.PyTreeCheckpointer()
    raw_restored = checkpointer.restore(path)
    return raw_restored

runner_state = (actor_state, vf_state, time_state, env_state, obsv, train_key)
if args.resume:
    raw_restored = io_callback(
        resuming_callback, runner_state, args.resume_checkpoint_path
    )
runner_state, metric = jax.lax.scan(
    _update_step, runner_state, None, args.num_iterations
)
However, JAX errors out with the complaint that my args.resume_checkpoint_path is a string, which is not compatible. Hence my current conundrum. Perhaps I'm setting this up wrong, or using the wrong JAX callback?
Sorry I didn't catch this message! I hope you've figured it out.
I think you need to make sure it's a static argument since you can't JIT a string as an argument.
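For what it's worth, here is a sketch of that workaround (names like init_runner_state and make_train are placeholders for however the script builds these, and item= is the legacy orbax restore API): do the restore in plain Python before jit ever sees anything, then pass the restored pytree in as an ordinary argument.

import jax
import orbax.checkpoint as ocp

checkpointer = ocp.PyTreeCheckpointer()

# Build the template runner_state the usual way (placeholder helper).
runner_state = init_runner_state(config)

if config["RESUME"]:
    # Plain Python: the string path is never traced, and the template
    # pytree tells orbax what structure to restore into.
    runner_state = checkpointer.restore(
        config["RESUME_CHECKPOINT_PATH"], item=runner_state
    )

# jit the loop itself; the (possibly restored) state enters as a normal
# pytree-of-arrays argument, so nothing string-valued crosses the trace.
train_fn = jax.jit(make_train(config))
runner_state, metrics = train_fn(runner_state)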
Actually, I have tried and failed even with a fixed filename. Restoration inside jit is quite difficult for me.
I know this sounds ridiculous, but I've spent ages trying to implement checkpoint saving in your example/walkthrough training code and have been getting nowhere.
Similarly (as it's also something to do every n steps or every epoch), I've been trying to reduce the frequency of metric collection, as it has been giving me VRAM errors on my lowly NVIDIA 3080.
Any advice or solutions would be very gratefully received.