tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Apache License 2.0

Does CheckPointer load previous agent, policy, and replay buffer data from saved files? #526

Open cosmir17 opened 3 years ago

cosmir17 commented 3 years ago

Hello,

There is a tutorial on Checkpointer. https://github.com/tensorflow/agents/blob/master/docs/tutorials/10_checkpointer_policysaver_tutorial.ipynb

In the Checkpointer and Restore checkpoint sections, I found the following code.

train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

The last line loads the previously saved global_step, but other than that, I don't see any code that loads the previously saved agent, policy, and replay_buffer information from the checkpointer. 1) Does initialize_or_restore() actually load this information behind the scenes, so that I can re-run my script and training resumes from the previous step with the updated q-network and policy? This is my first question. (I am aware that Reverb can be used to save the replay buffer; my question is about the Checkpointer.)

The following is my code.

q_net = q_network.QNetwork(
    nimble_quest_env.observation_spec(),
    nimble_quest_env.action_spec(),
    conv_layer_params=conv_layer_params,
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
#global_step = tf.compat.v1.train.get_or_create_global_step()
global_step = tf.compat.v1.train.get_global_step()

#########################################################################
agent = dqn_agent.DdqnAgent(
    nimble_quest_env.time_step_spec(),
    nimble_quest_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=global_step)

agent.initialize()
##########################################################################

eval_policy = agent.policy
collect_policy = agent.collect_policy

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=nimble_quest_env.batch_size,
    max_length=replay_buffer_max_length)

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
#global_step=global_step
)
policy_dir = os.path.join(tempdir, 'policy_ddqn')

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
tf_policy_saver = policy_saver.PolicySaver(agent.policy)

I have three more questions:

  1. I couldn't pass 'global_step' to the common.Checkpointer(...) constructor when running a second time and restoring from a checkpoint, because it complained that global_step was None. That's why I had to run global_step = tf.compat.v1.train.get_global_step() again. Am I doing this right?

  2. After this, my Python script resumed from the previous training step (step no. 121, etc.), but I got the message: function with signature contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. Is it OK to ignore this message? Do I still get the updated q-network weights from the checkpoint files?

  3. The regular TF/Keras ModelCheckpoint callback has save_best_only, as in the following example. Can PolicySaver() do something similar? checkpoint = ModelCheckpoint(weightPath, monitor='loss', verbose=1, save_best_only=True, mode='min')
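To make question 3 concrete, the only workaround I can think of right now is hand-rolling it around PolicySaver, roughly like the sketch below (best_avg_return and maybe_save_best_policy are just names I made up; computing the evaluation metric is omitted):

best_avg_return = -float('inf')

def maybe_save_best_policy(avg_return):
    # Mimic Keras ModelCheckpoint(save_best_only=True): only export the
    # policy when the evaluation metric improves.
    global best_avg_return
    if avg_return > best_avg_return:
        best_avg_return = avg_return
        tf_policy_saver.save(policy_dir)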

Thank you, I wish you a splendid day!

Sean

kbanoop commented 3 years ago

Hi Sean,

Trying to understand your question more. Is your code significantly different from our DQN example: https://github.com/tensorflow/agents/blob/master/tf_agents/agents/dqn/examples/v2/train_eval.py

This shows how to create the global step and save it in the checkpoint as well. initialize_or_restore() will restore it when you run again.
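Roughly, the pattern in that example looks like this (a condensed sketch, not the full script; it assumes the agent and replay buffer are built the same way as in your snippet):

# Create the global step once, up front, before the agent and checkpointer.
global_step = tf.compat.v1.train.get_or_create_global_step()

# ... build q_net, optimizer, agent (with train_step_counter=global_step)
# and replay_buffer as in your code ...

train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step)

# First run: initializes everything. Later runs: restores the agent, policy,
# replay buffer and global_step variables in place from the checkpoint.
train_checkpointer.initialize_or_restore()

# ... training loop ...
# Periodically persist everything, indexing the checkpoint by the step value.
train_checkpointer.save(global_step=global_step.numpy())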

cosmir17 commented 3 years ago

Hi @kbanoop, it looks like your example only demonstrates how to save state (initialize without restore), not the restoration part. Hence, it doesn't contain global_step = tf.compat.v1.train.get_global_step().

Nonetheless, I am already aware that global_step gets successfully restored in my example. My question is about the restoration of the other variables, as shown in:

train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step)

I raised this issue because the agent's behaviour seemed to reset to a blank state when I resumed from a previously saved session. Even global_step wasn't restored without an explicit get_global_step() call after initialize_or_restore(). So, I ask again, please: does train_checkpointer.initialize_or_restore() restore the agent, replay_buffer, and policy variable data automatically? If not, why does the example I found (https://github.com/tensorflow/agents/blob/master/docs/tutorials/10_checkpointer_policysaver_tutorial.ipynb), also mentioned above, pass those variables to the Checkpointer?

Would it also be OK to answer the other questions (Q2, Q3, and Q4), please?

kbanoop commented 3 years ago

Hi Sean,

Yes, the train_checkpointer.initialize_or_restore() in the DQN train_eval I linked is meant to restore the agent, policy variables, etc. https://github.com/tensorflow/agents/blob/5ec6898513ed31cdac6819604536a410e1a948d3/tf_agents/agents/dqn/examples/v2/train_eval.py#L220 The training job can be killed, and when restored, it should continue training from the same place. If that's not happening, it is a bug.

Can you try running the train_eval example in a terminal instead of the checkpointer tutorial notebook? In the notebook, if you run the same cell multiple times, multiple variables may be created, and it's harder to debug.
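One quick way to sanity-check the restore (just a sketch; it assumes the objects from your snippet and a checkpoint left on disk by a previous run):

print('step before restore:', agent.train_step_counter.numpy())
print('q_net first-layer weight norm before restore:',
      tf.norm(q_net.trainable_weights[0]).numpy())

train_checkpointer.initialize_or_restore()

# On a fresh run these just show the initial values; on a second run,
# with a checkpoint on disk, they should match where training left off.
print('step after restore:', agent.train_step_counter.numpy())
print('q_net first-layer weight norm after restore:',
      tf.norm(q_net.trainable_weights[0]).numpy())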

Re:

  1. You only need to create global_step once, before the agent and checkpointer are created. You don't have to create it again after initialize_or_restore (see the sketch after this list). https://github.com/tensorflow/agents/blob/5ec6898513ed31cdac6819604536a410e1a948d3/tf_agents/agents/dqn/examples/v2/train_eval.py#L142
  2. Does this happen only in the notebook, or when you run the train_eval in a terminal as well?
  3. Unfortunately, we don't support this yet.
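
To make point 1 concrete, the intended ordering is roughly this (a sketch based on your snippet, not a complete script):

# Create the global step exactly once, before the agent and the checkpointer.
global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DdqnAgent(
    nimble_quest_env.time_step_spec(),
    nimble_quest_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=global_step)
agent.initialize()

train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step)

# Restores into the existing variables. `global_step` is the same Python
# object the checkpoint restores into, so after this call it already holds
# the saved step count; re-fetching it with get_global_step() only returns
# the same variable again.
train_checkpointer.initialize_or_restore()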
cosmir17 commented 3 years ago

Hi @kbanoop, thank you for your kind reply. I didn't use a notebook; everything ran as a Python project in PyCharm, so it's the same as running in a terminal. The following is the actual project I am working on: https://github.com/cosmir17/nimble-quest-tf-agent/blob/master/main_tf_agents_ddqn.py (The code can't be used as-is for both the initial save and the subsequent save & load runs; lines 52, 53 and 83 need to be commented in or out to make it work. Currently, it's set up for the subsequent save & load run. For example, if I don't comment out line 83 in the subsequent save & load run, I get the following error message: ValueError: Checkpoint was expecting a trackable object (an object derived from TrackableBase), got None. If you believe this object should be trackable)

  1. If I don't call tf.compat.v1.train.get_global_step() after initialize_or_restore(), the step is 0 when I resume (load) from a previous state. I used tf.compat.v1.train.get_or_create_global_step() first because agent = dqn_agent.DdqnAgent(...) requires a step counter. So, I had to declare global_step twice. Also, the TF-Agents tutorial recommends doing so after initialize_or_restore: https://github.com/tensorflow/agents/blob/master/docs/tutorials/10_checkpointer_policysaver_tutorial.ipynb

    For this to work, the whole set of objects should be recreated the same way as when the checkpoint was created.
    train_checkpointer.initialize_or_restore()
    global_step = tf.compat.v1.train.get_global_step()

    If what you said is true, the doc seems poorly documented, or at least confusing.

  2. Everything ran as a Python project in PyCharm. No notebook.

Trafalgar98 commented 1 year ago

Hi Sean,

Were you able to fix the warning in your third point?

sabrysm commented 1 year ago

Re Sean's comment above ("If what you said is true, the doc seems poorly documented, or at least confusing"):

I agree, the doc is confusing.