michaelnny / deep_rl_zoo

A collection of Deep Reinforcement Learning algorithms implemented with PyTorch to solve Atari games and classic control tasks like CartPole, LunarLander, and MountainCar.
Apache License 2.0

Does training resume from last saved checkpoint? #21

Open KishoreP1 opened 8 months ago

KishoreP1 commented 8 months ago

When training is interrupted and later resumed, I expect the process to pick up from the last saved checkpoint. However, even when specifying the same --checkpoint_dir flag, training restarts from the first iteration, disregarding previously completed iterations.

I tried:

  1. Start training with a specified --checkpoint_dir.
  2. Allow the training to proceed past a few iterations (e.g., 12 iterations).
  3. Interrupt the training process.
  4. Resume training with the same --checkpoint_dir flag.

I expected the training to resume from iteration 13, considering the last completed iteration was 12. However, the training restarts from iteration 1, ignoring the checkpoints saved in the specified directory.

Inside run_learner in main_loop.py, the checkpointing and iteration-logging logic looks correct. However, I cannot find where the code loads a checkpoint to resume training from the last saved state.

# Start training
for iteration in range(1, num_iterations + 1):
    logging.info(f'Training iteration {iteration}')
    logging.info(f'Starting {learner.agent_name} ...')

    # Update shared iteration count.
    iteration_count.value = iteration

    # Set start training event.
    start_iteration_event.set()
    learner.reset()

    run_learner_loop(learner, data_queue, num_actors, learner_trackers)

    start_iteration_event.clear()
    checkpoint.set_iteration(iteration)
    saved_ckpt = checkpoint.save()

    if saved_ckpt:
        logging.info(f'New checkpoint created at "{saved_ckpt}"')
michaelnny commented 8 months ago

Hi, the training scripts do not currently support resuming training. As you can see from the code, the --checkpoint_dir argument only specifies where to save model checkpoints; it does not look for an existing checkpoint to continue training from.

You should be able to adapt the code to look for the latest model checkpoint if required. Here's an example of manually loading a checkpoint file in the eval_agent.py module:

    if FLAGS.load_checkpoint_file:
        checkpoint.restore(FLAGS.load_checkpoint_file)
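
For illustration, adapting run_learner could look something like the sketch below. The '*.ckpt' glob pattern and the get_iteration() accessor are assumptions made for this example, so adjust them to whatever file naming and checkpoint API the repo actually uses.

    import glob
    import os

    # Sketch only: before the training loop in run_learner, look for the most
    # recent checkpoint in --checkpoint_dir and restore it.
    resume_iteration = 1
    if FLAGS.checkpoint_dir and os.path.isdir(FLAGS.checkpoint_dir):
        # The '*.ckpt' pattern is an assumption; match whatever checkpoint.save() writes.
        ckpt_files = sorted(glob.glob(os.path.join(FLAGS.checkpoint_dir, '*.ckpt')))
        if ckpt_files:
            latest_ckpt = ckpt_files[-1]
            checkpoint.restore(latest_ckpt)
            logging.info(f'Resuming from checkpoint "{latest_ckpt}"')
            # get_iteration() is hypothetical; use whatever the checkpoint object
            # actually exposes to recover the saved iteration number.
            resume_iteration = checkpoint.get_iteration() + 1

    # Start (or resume) training from the recovered iteration instead of 1.
    for iteration in range(resume_iteration, num_iterations + 1):
        ...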

However, keep in mind that the code only saves the model state, not the optimizer or the agent's internal state (number of updates, etc.). You would also need to correctly handle logging to TensorBoard and the CSV files when resuming.
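
If you also need the optimizer state and iteration counter to survive a restart, you would have to persist them yourself, for example with plain torch.save / torch.load next to the regular checkpoint. This is only a sketch: the extra_state.pt file name is made up, and it assumes the learner exposes its optimizer as learner.optimizer, which the repo does not guarantee.

    import os
    import torch

    # Sketch only: persist extra state (optimizer, iteration) yourself, since
    # the repo's checkpoint only stores the model weights.
    extra_state_path = os.path.join(FLAGS.checkpoint_dir, 'extra_state.pt')

    # Save at the end of each iteration, next to the regular checkpoint.
    # 'learner.optimizer' is an assumption about the learner's attributes.
    torch.save(
        {
            'iteration': iteration,
            'optimizer': learner.optimizer.state_dict(),
        },
        extra_state_path,
    )

    # On resume, load it back before entering the training loop.
    if os.path.exists(extra_state_path):
        extra_state = torch.load(extra_state_path)
        learner.optimizer.load_state_dict(extra_state['optimizer'])
        resume_iteration = extra_state['iteration'] + 1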