I have solved this by initializing the variables before the `callback.on_training_start` call:

```python
done = False
callback.on_training_start(locals(), globals())
```

Should I make a PR with this change? Which other variables should I initialise?
I was just about to test this, but you got ahead of me :thumbsup:.
I think it would be nice to have the same treatment for other algorithms as well, where `locals` has access to things like `done` and `reward` (which can be quite relevant for the `on_step` callback). This could include all `env.step`-related variables (`obs`, `reward`, `done`, `info` and the chosen `action`). I am not 100% sure this can be done with all algorithms, but I can see use-cases for this kind of information.
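For example, a hypothetical sketch of such a callback (the key names `done` and `rew` match the DQN locals dumps shown later in this thread, but may differ in other algorithms):

```python
from stable_baselines.common.callbacks import BaseCallback

class EpisodeEndLogger(BaseCallback):
    """Hypothetical callback that inspects env.step results via self.locals."""

    def _on_step(self):
        # 'done' and 'rew' are the names used inside DQN.learn; other
        # algorithms may use different names, so fall back to defaults.
        done = self.locals.get('done', False)
        reward = self.locals.get('rew', 0.0)
        if done:
            print("Episode finished, last reward: {}".format(reward))
        return True  # returning False would stop training
```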
@araffin Quick thoughts on this, with you being the main author of this code?
> Quick thoughts on this

I'm not sure, I would rather do something like `callback.update_locals(locals())` after each step.
> I'm not sure, I would rather do something like `callback.update_locals(locals())` after each step.

It's not necessary to introduce a function in the callbacks: apparently, within a given function, `locals()` returns the same dict object (a singleton for that frame), which is refreshed on every `locals()` call.
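A minimal sketch demonstrating this behaviour (note this is a CPython implementation detail; it holds for the Python versions stable-baselines targets, but PEP 667 changes it in Python 3.13, where `locals()` in a function returns a fresh snapshot on each call):

```python
def demo():
    a = 1
    first = locals()   # the frame's locals as a dict
    b = 2
    second = locals()  # calling locals() again refreshes that same dict
    assert first is second  # same object for this frame (CPython detail)
    assert 'b' in first     # the earlier reference now sees 'b' too

demo()
```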
In DQN:

```python
# Stop training if return value is False
locals()  # update the locals here
if callback.on_step() is False:
    break
```
Test:

```python
import gym
import stable_baselines as sb
from stable_baselines.common.callbacks import BaseCallback

env = gym.make("MountainCar-v0")

class TestCallback(BaseCallback):
    def _on_step(self):
        assert 'done' in self.locals
        return True  # keep training

    def _on_training_start(self):
        assert 'done' not in self.locals

agent = sb.DQN('MlpPolicy', env, verbose=1)
agent.learn(100, callback=TestCallback())
```
Notes:

At the time `_on_training_start` is called, the locals contain:

```python
{'self': <stable_baselines.deepq.dqn.DQN object at 0x152c6ead0>, 'total_timesteps': 100, 'callback': <__main__.TestCallback object at 0x152c6eb50>, 'log_interval': 100, 'tb_log_name': 'DQN', 'reset_num_timesteps': True, 'replay_wrapper': None, 'new_tb_log': True, 'writer': None, 'exploration_annealing_steps': 10, 'episode_rewards': [0.0], 'episode_successes': []}
```

On the first `_on_step`:

```python
{'self': <stable_baselines.deepq.dqn.DQN object at 0x152c6ead0>, 'total_timesteps': 100, 'callback': <__main__.TestCallback object at 0x152c6eb50>, 'log_interval': 100, 'tb_log_name': 'DQN', 'reset_num_timesteps': True, 'replay_wrapper': None, 'new_tb_log': True, 'writer': None, 'exploration_annealing_steps': 10, 'episode_rewards': [0.0], 'episode_successes': [], 'reset': False, 'obs': array([-0.4033148, 0. ]), '_': 0, 'kwargs': {}, 'update_eps': 1.0, 'update_param_noise_threshold': 0.0, 'action': 0, 'env_action': 0, 'new_obs': array([-0.40519748, -0.00188268]), 'rew': -1.0, 'done': False, 'info': {}}
```

On the second `_on_step`:

```python
{'self': <stable_baselines.deepq.dqn.DQN object at 0x152c6ead0>, 'total_timesteps': 100, 'callback': <__main__.TestCallback object at 0x152c6eb50>, 'log_interval': 100, 'tb_log_name': 'DQN', 'reset_num_timesteps': True, 'replay_wrapper': None, 'new_tb_log': True, 'writer': None, 'exploration_annealing_steps': 10, 'episode_rewards': [-1.0], 'episode_successes': [], 'reset': False, 'obs': array([-0.40519748, -0.00188268]), '_': 1, 'kwargs': {}, 'update_eps': 0.902, 'update_param_noise_threshold': 0.0, 'action': 1, 'env_action': 1, 'new_obs': array([-0.40794961, -0.00275213]), 'rew': -1.0, 'done': False, 'info': {}, 'obs_': array([-0.4033148, 0. ]), 'new_obs_': array([-0.40519748, -0.00188268]), 'reward_': -1.0, 'can_sample': False, 'mean_100ep_reward': -inf, 'num_episodes': 1}
```
So, depending on how this is implemented, the documentation should note whether each variable is available on every step or only from the second step onwards.
> So, depending on how this is implemented, the documentation should note whether each variable is available on every step or only from the second step onwards.

I agree that, at the very least, this should be documented, so people know the information only becomes available on the second step, and that it is just the first step(s) that break their code.
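Until that is documented, a defensive pattern like the following works around the missing keys (a hypothetical sketch, not an official recommendation):

```python
from stable_baselines.common.callbacks import BaseCallback

class SafeCallback(BaseCallback):
    """Hypothetical callback that tolerates keys missing on early steps."""

    def _on_step(self):
        # Some entries (e.g. 'mean_100ep_reward' in the dumps above) only
        # appear from the second step onwards, so use .get with a default.
        mean_reward = self.locals.get('mean_100ep_reward')
        if mean_reward is not None:
            print("Mean reward over last 100 episodes:", mean_reward)
        return True
```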
@araffin, how would you like this to be dealt with?
I would prefer an explicit call to `.update_locals(locals())` rather than a lone call to `locals()`, if that is the question. And also an update to the documentation in any case.
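For reference, such a method could be as simple as the following sketch (hypothetical, not the merged implementation):

```python
class BaseCallback:
    """Sketch: only the part relevant to this discussion is shown."""

    def __init__(self):
        self.locals = {}

    def update_locals(self, locals_):
        # Explicit entry point for the training loop to push its locals,
        # instead of relying on the locals() singleton side effect.
        self.locals.update(locals_)
```

The training loop would then call `callback.update_locals(locals())` after each environment step.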
Okay, I will get on with the PR then.
Description

From the discussion in #756, callbacks in DQN should have access to local variables, e.g. `done`, through `BaseCallback.locals` / `self.locals`. However, this isn't the case.

Code example

Workaround

In `DQN.learn`, line 219, the following results in an assertion error. If we call `locals()` first, then the assertion succeeds.
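For context, the relevant spot looks roughly like this (a reconstruction based on the snippet quoted earlier in this thread, not a verbatim copy of the source around line 219):

```python
# Inside DQN.learn, just before the callback check (reconstruction):
locals()  # refresh the frame's locals dict so self.locals sees the step variables

# Stop training if return value is False
if callback.on_step() is False:
    break
```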