luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0

Checkpointing and less regular metric collection #13

Open jheagerty opened 1 year ago

jheagerty commented 1 year ago

I know this sounds ridiculous, but I've spent ages trying to implement checkpoint saving in your example/walkthrough training code and have been getting nowhere.

Similarly (as it's something to do every n steps or every epoch) I've been trying to reduce the frequency of metric collection, as it has been giving me VRAM errors with my lowly NVIDIA 3080.

Any advice / solutions would be very gratefully received.

jheagerty commented 1 year ago

And sorry, for context, the checkpointing is so that I can implement self-play where the baseline model that controls the enemy is updated to match the model we're training every once in a while.

jheagerty commented 1 year ago

Never mind, figured out the main thing for me, checkpointing. You have to:

My one concern is that I doubt/cannot tell whether the learning rate scheduler is maintained, but I will worry about that later.

On less regular metrics, I will also worry about that later, but I've seen some likely jittable tools.

Howuhh commented 11 months ago

@jheagerty actually I think you can save checkpoints under jit easily with callbacks, such as jax.experimental.io_callback() (for example, inside _update_step to save after each update)
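A minimal sketch of that idea, assuming a toy carry of (step, params) in place of the real runner state and pickle in place of a proper checkpointing library. The names save_checkpoint and CKPT_DIR are hypothetical, and the "+ 1.0" update is only a stand-in for a real PPO update:

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

CKPT_DIR = tempfile.mkdtemp()  # hypothetical checkpoint directory


def save_checkpoint(step, params):
    # Runs on the host, outside the compiled program, so file I/O is fine here.
    path = os.path.join(CKPT_DIR, f"ckpt_{int(step)}.pkl")
    with open(path, "wb") as f:
        pickle.dump(jax.tree_util.tree_map(np.asarray, params), f)


def _update_step(carry, _):
    step, params = carry
    # Stand-in for the real update; assumption for illustration only.
    params = jax.tree_util.tree_map(lambda p: p + 1.0, params)
    step = step + 1
    # The callback returns nothing, so the result spec is None.
    # ordered=True asks JAX to run the saves in program order.
    io_callback(save_checkpoint, None, step, params, ordered=True)
    return (step, params), None


@jax.jit
def train(params):
    carry = (jnp.int32(0), params)
    (step, params), _ = jax.lax.scan(_update_step, carry, None, length=3)
    return step, params


step, params = train({"w": jnp.zeros(2)})
jax.effects_barrier()  # wait for outstanding callbacks before reading the files
saved = sorted(os.listdir(CKPT_DIR))
```

Saving on every update is probably too frequent for real training; wrapping the callback in jax.lax.cond on the step counter is one way to thin it out.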

luchris429 commented 11 months ago

yep! I do it the way @Howuhh is describing. If you look at the code, there is the debug callback. You can just replace the print function with your checkpointing and wandb logging.
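As a rough sketch of swapping the print for logging, and of doing it only every few updates, something like the following might work. LOG_EVERY, log_metrics and the logged list are hypothetical stand-ins (the list plays the role of wandb.log or a checkpoint write), and the update itself is a toy. Because the scan returns None per step instead of a metrics pytree, nothing is stacked across iterations, which may also help with the memory pressure mentioned at the top of the thread:

```python
import jax
import jax.numpy as jnp

LOG_EVERY = 5  # hypothetical logging interval
logged = []    # stand-in for wandb.log or a checkpoint write


def log_metrics(step, value):
    # Host-side function; replace the body with real logging.
    logged.append((int(step), float(value)))


def _update_step(carry, _):
    step, value = carry
    value = value * 0.9 + 1.0  # toy stand-in for a training metric
    step = step + 1
    # Only fire the callback on every LOG_EVERY-th update.
    jax.lax.cond(
        step % LOG_EVERY == 0,
        lambda: jax.debug.callback(log_metrics, step, value),
        lambda: None,
    )
    # Returning None here means scan does not stack per-step metrics.
    return (step, value), None


@jax.jit
def train():
    carry = (jnp.int32(0), jnp.float32(0.0))
    carry, _ = jax.lax.scan(_update_step, carry, None, length=20)
    return carry


final_step, final_value = train()
jax.effects_barrier()  # make sure all callbacks have run
logged_steps = sorted(s for s, _ in logged)
```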

jheagerty commented 11 months ago

Thanks so much! Will look into this

luchris429 commented 11 months ago

The way you do it should be fine too, and is arguably better (though it takes more code). The optimizer parameters (which include the lr scheduling info) should be in the train_state.
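To illustrate why saving the whole train state should carry the schedule along, here is a plain-JAX sketch (not optax, and not the repository's code). It assumes the learning rate is a pure function of an update counter stored in the optimizer state; lr_schedule, update and the dict layout are all made up for illustration:

```python
import pickle

import jax
import jax.numpy as jnp
import numpy as np


def lr_schedule(count):
    # Hypothetical linear decay driven only by the update counter.
    return 0.1 * (1.0 - count / 10.0)


@jax.jit
def update(train_state, grad):
    count = train_state["opt_state"]["count"]
    lr = lr_schedule(count)
    return {
        "params": train_state["params"] - lr * grad,
        "opt_state": {"count": count + 1},
    }


def run(train_state, n):
    for _ in range(n):
        train_state = update(train_state, jnp.ones(2))
    return train_state


init = {"params": jnp.zeros(2), "opt_state": {"count": jnp.int32(0)}}
uninterrupted = run(init, 5)

# Save after 3 updates, restore, then run the remaining 2.
blob = pickle.dumps(jax.tree_util.tree_map(np.asarray, run(init, 3)))
restored = jax.tree_util.tree_map(jnp.asarray, pickle.loads(blob))
resumed = run(restored, 2)
```

If the counter were dropped from the checkpoint, the resumed run would restart the schedule from its initial learning rate, so it seems worth checking that the saved state includes it.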

Chulabhaya commented 8 months ago

@luchris429 Hi Chris! I was wondering, how did you implement restoring of checkpoints in the PureJaxRL end-to-end jitting? I'm able to save checkpoints pretty easily with a debug callback function, but I can't quite figure out how to restore. I attempted to put an experimental.io_callback function in the train function, but I can't actually do anything with the string checkpoint path because JAX can't handle strings.

luchris429 commented 8 months ago

You can try to load the runner state here!

Does it not work if you set the filename in the config?

Chulabhaya commented 8 months ago

So I tried something like the code below at exactly the line you pointed out (in a modified PPO script where I split the actor/critic):

def resuming_callback(path):
    checkpointer = ocp.PyTreeCheckpointer()
    raw_restored = checkpointer.restore(path)
    return raw_restored

runner_state = (actor_state, vf_state, time_state, env_state, obsv, train_key)
if args.resume:
    raw_restored = io_callback(
        resuming_callback, runner_state, args.resume_checkpoint_path
    )

runner_state, metric = jax.lax.scan(
    _update_step, runner_state, None, args.num_iterations
)

However JAX errors out with the complaint that my args.resume_checkpoint_path is a string which is not compatible. Hence my current conundrum. Perhaps I'm setting this up wrong or using the wrong JAX callback?
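One possible reading of the error: io_callback takes (callback, result_shape_dtypes, *args), so in the snippet above runner_state is used as the result spec and the path is passed as an operand, and operands have to be arrays. A sketch of a workaround is to bind the path into the callback in Python (here with functools.partial) so it never becomes an operand. This uses pickle and a toy two-field state rather than orbax and the real runner state, and the checkpoint file is written first just to make the example self-contained:

```python
import functools
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def resuming_callback(path):
    # Host-side load; must return a pytree matching the result spec.
    with open(path, "rb") as f:
        return pickle.load(f)


# Write a toy checkpoint to restore from (assumption for illustration).
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
with open(ckpt_path, "wb") as f:
    pickle.dump({"w": np.full(3, 7.0, np.float32), "step": np.int32(42)}, f)


@jax.jit
def train(runner_state):
    # Describe the shapes and dtypes the callback will return.
    spec = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), runner_state
    )
    # The string is closed over, not passed as a JAX operand.
    restore = functools.partial(resuming_callback, ckpt_path)
    runner_state = io_callback(restore, spec)
    return runner_state


fresh_state = {"w": jnp.zeros(3), "step": jnp.int32(0)}
restored_state = train(fresh_state)
```

Note that the restored value has to be assigned back to runner_state; in the snippet above it ends up in raw_restored and is never used by the scan.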

luchris429 commented 7 months ago

Sorry I didn't catch this message! I hope you've figured it out.

I think you need to make sure it's a static argument since you can't JIT a string as an argument.
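A sketch of the static-argument route, under the assumption that the checkpoint is a pickled dict of arrays (not orbax) and that the state is a toy two-field dict. With ckpt_path marked static it stays an ordinary Python string while the function is traced, so the file can be read with normal Python code. The names load_state and ckpt_path are hypothetical:

```python
import functools
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def load_state(path):
    with open(path, "rb") as f:
        return pickle.load(f)


@functools.partial(jax.jit, static_argnames=("ckpt_path", "num_iterations"))
def train(runner_state, ckpt_path=None, num_iterations=2):
    if ckpt_path is not None:
        # Plain Python branch, evaluated once at trace time.
        runner_state = jax.tree_util.tree_map(jnp.asarray, load_state(ckpt_path))

    def _update_step(state, _):
        # Toy stand-in for the real update.
        return {"w": state["w"] + 1.0, "step": state["step"] + 1}, None

    runner_state, _ = jax.lax.scan(_update_step, runner_state, None, num_iterations)
    return runner_state


# Toy checkpoint so the example is self-contained.
path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
with open(path, "wb") as f:
    pickle.dump({"w": np.full(2, 5.0, np.float32), "step": np.int32(10)}, f)

fresh = {"w": jnp.zeros(2), "step": jnp.int32(0)}
from_scratch = train(fresh)
resumed = train(fresh, ckpt_path=path)
```

One caveat with this design: the file is read when the function is traced, so the loaded arrays are baked into the compiled program, and a later call with the same path reuses the cached version even if the file has changed. Loading the checkpoint outside jit and passing the restored state in as a normal argument avoids that and may be the simpler option.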

gzadigo commented 6 months ago

Actually, I tried this even with a fixed filename and it still failed. Restoring a checkpoint inside jit is proving quite difficult for me.