eleurent / rl-agents

Implementations of Reinforcement Learning and Planning algorithms
MIT License

How do I save a trained model to reuse it? #69

Open · shivambhandari99 opened this issue 3 years ago

shivambhandari99 commented 3 years ago

I specifically want to save the DQN network I've been working on. I went through the documentation and couldn't find anything. I also tried pickling my agent after training, but that doesn't seem to work either. I'd appreciate any input.

eleurent commented 3 years ago

During training, the agent model is saved through the save_agent_model function: https://github.com/eleurent/rl-agents/blob/a290be38351cf29c03779cb6683d831a06b74864/rl_agents/trainer/evaluation.py#L276

There are two cases where this function is used to automatically save the model in after_some_episodes: https://github.com/eleurent/rl-agents/blob/a290be38351cf29c03779cb6683d831a06b74864/rl_agents/trainer/evaluation.py#L318

  1. When the "episode is selected", which is decided by OpenAI Gym's gym.wrappers.Monitor.is_episode_selected method. By default, gym uses the capped_cubic_video_schedule: a cubic progression for early episodes (1, 8, 27, ...) and then every 1000 episodes (1000, 2000, 3000, ...). This can be changed through the video_callable argument of the Monitor; a sketch of the default schedule follows this list.

https://github.com/openai/gym/blob/a5a6ae6bc0a5cfc0ff1ce9be723d59593c165022/gym/wrappers/monitor.py#L254

  2. When a new best performance is reached (averaged over a window of 50 episodes).
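
For reference, gym's default capped_cubic_video_schedule looks roughly like this (reconstructed from gym's monitor module of that era; check your installed version):

def capped_cubic_video_schedule(episode_id):
    # Record on perfect cubes (0, 1, 8, 27, ...) for the first 1000 episodes,
    # then every 1000 episodes afterwards.
    if episode_id < 1000:
        return int(round(episode_id ** (1. / 3))) ** 3 == episode_id
    else:
        return episode_id % 1000 == 0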

The resulting models are saved as .tar files, by default both in the run directory and in an out/<env>/<agent>/saved_models directory.

Then, in order to load a saved model, you just have to use the --recover flag to load the latest model from the out/<env>/<agent>/saved_models directory, or --recover-from=<path> to pick a specific model.
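
For example, assuming the experiments.py entry point described in the repository README (the config paths here are illustrative, not prescribed):

python3 experiments.py evaluate configs/HighwayEnv/env.json configs/HighwayEnv/agents/DQNAgent/dqn.json --test --recover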

shivambhandari99 commented 3 years ago

Thank you for your reply. What should I do to recover the agent to be used outside of evaluation or experiments? For example, I want to run this snippet locally:

# env and agent are assumed to be already constructed.
obs = env.reset()
for time in range(100):
    if time == 0:
        action = 1
    else:
        action = agent.act(obs)
    # print(DiscreteMetaAction.ACTIONS_ALL[action])
    obs, reward, done, info = env.step(action)
    env.render()
    # plt.imshow(env.render(mode="rgb_array"))
    # plt.show()
    if done:
        print(time)
        break
eleurent commented 3 years ago

Oh, I see! Then you can just use agent.save(model_path) and agent.load(model_path).
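
A minimal sketch of that workflow, assuming env and agent have been rebuilt with the same configurations used during training (the model path below is illustrative):

# After training: persist the trained weights.
model_path = "out/HighwayEnv/DQNAgent/saved_models/latest.tar"  # illustrative path
agent.save(model_path)

# Later, e.g. in a fresh process: rebuild the same agent, restore the
# weights, and act in the environment without an Evaluation object.
agent.load(model_path)
obs = env.reset()
for time in range(100):
    action = agent.act(obs)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        print(time)
        break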