hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

[Bug] Callbacks can't access local variables #760

Closed · m-rph closed this 4 years ago

m-rph commented 4 years ago

Description

From the discussion in #756, callbacks in DQN should have access to local variables (e.g. done) through BaseCallback.locals / self.locals. However, this is not the case.

Code example

from stable_baselines.common.callbacks import BaseCallback
import gym
import stable_baselines as sb
env = gym.make("MountainCar-v0")
class TestCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)
    def _on_step(self):
        assert 'done' in self.locals
        return True

agent = sb.DQN('MlpPolicy', env, verbose=1)
agent.learn(100, callback=TestCallback())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/stelios/anaconda3/envs/thesis37/lib/python3.7/site-packages/stable_baselines/deepq/dqn.py", line 224, in learn
    if callback.on_step() is False:
  File "/Users/stelios/anaconda3/envs/thesis37/lib/python3.7/site-packages/stable_baselines/common/callbacks.py", line 89, in on_step
    return self._on_step()
  File "<stdin>", line 5, in _on_step
AssertionError

Workaround

In DQN.learn, around line 219, the following results in an assertion error: callback.locals is still the dict captured at on_training_start, and done has not yet been copied into it.

new_obs, rew, done, info = self.env.step(env_action)

self.num_timesteps += 1
assert 'done' in callback.locals  # fails
# Stop training if return value is False
if callback.on_step() is False:
    break

If we call locals() just before the check, the frame's dict is refreshed and the assertion succeeds.

new_obs, rew, done, info = self.env.step(env_action)

self.num_timesteps += 1
locals()
assert 'done' in callback.locals  # succeeds
# Stop training if return value is False
if callback.on_step() is False:
    break
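
This works because of a CPython implementation detail: at function scope, locals() returns the same dict object for a given frame and refreshes it as a snapshot on every call; assigning a new local does not update that dict automatically. (Note: Python 3.13 / PEP 667 changes this so that each locals() call returns an independent snapshot, which would break this trick.) A minimal standalone sketch of the pre-3.13 behaviour:

def demo():
    snapshot = locals()        # empty snapshot: nothing is bound in the frame yet
    done = False               # new local; the stale snapshot does not see it
    assert 'done' not in snapshot
    locals()                   # refreshes the same dict object in place
    assert 'done' in snapshot  # the old reference now sees done

demo()
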
m-rph commented 4 years ago

I have solved this by initializing the variables before the callback.on_training_start call.

done = False
callback.on_training_start(locals(), globals())

Should I make a PR with this change? Which other variables should I initialise?

Miffyli commented 4 years ago

I was just about to test this, but you got ahead of me :thumbsup:.

I think it would be nice to have the same treatment for the other algorithms as well, so that locals has access to things like done and reward (this can be quite relevant for the on_step callback). This could include all env.step-related variables (obs, reward, done, info and the chosen action). I am not 100% sure this can be done for every algorithm, but I can see use-cases for this kind of information.
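
A minimal sketch of what that initialisation could look like at the top of a learn method (the variable names follow the DQN locals dumped further down in this thread; this is only an illustration, not a final patch):

def learn_preamble_sketch(env, callback):
    # Illustration only: pre-declare the env.step-related locals so the dict
    # captured by on_training_start already contains these keys on step one.
    obs = env.reset()
    action, env_action = None, None
    new_obs, rew, done, info = None, 0.0, False, {}
    callback.on_training_start(locals(), globals())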

@araffin Quick thoughts on this, with you being the main author of this code?

araffin commented 4 years ago

Quick thoughts on this

I'm not sure, I would rather do something like callback.update_locals(locals()) after each step.

m-rph commented 4 years ago

I'm not sure, I would rather do something like callback.update_locals(locals()) after each step.

It's not necessary to introduce a function in the callbacks: apparently, locals() returns the same dict for a given function frame, and that dict is updated on every locals() call. In DQN:

# Stop training if return value is False
locals() # update the locals here
if callback.on_step() is False:
    break

Test:

from stable_baselines.common.callbacks import BaseCallback
import gym
import stable_baselines as sb
env = gym.make("MountainCar-v0")
class TestCallback(BaseCallback):
    def _on_step(self):
        assert 'done' in self.locals
        return True
    def _on_training_start(self):
        assert 'done' not in self.locals

agent = sb.DQN('MlpPolicy', env, verbose=1)
agent.learn(100, callback=TestCallback())

Notes:

At the time _on_training_start is called, locals contains:

{'self': <stable_baselines.deepq.dqn.DQN object at 0x152c6ead0>, 'total_timesteps': 100, 'callback': <__main__.TestCallback object at 0x152c6eb50>, 'log_interval': 100, 'tb_log_name': 'DQN', 'reset_num_timesteps': True, 'replay_wrapper': None, 'new_tb_log': True, 'writer': None, 'exploration_annealing_steps': 10, 'episode_rewards': [0.0], 'episode_successes': []}

On the first _on_step:

{'self': <stable_baselines.deepq.dqn.DQN object at 0x152c6ead0>, 'total_timesteps': 100, 'callback': <__main__.TestCallback object at 0x152c6eb50>, 'log_interval': 100, 'tb_log_name': 'DQN', 'reset_num_timesteps': True, 'replay_wrapper': None, 'new_tb_log': True, 'writer': None, 'exploration_annealing_steps': 10, 'episode_rewards': [0.0], 'episode_successes': [], 'reset': False, 'obs': array([-0.4033148,  0.       ]), '_': 0, 'kwargs': {}, 'update_eps': 1.0, 'update_param_noise_threshold': 0.0, 'action': 0, 'env_action': 0, 'new_obs': array([-0.40519748, -0.00188268]), 'rew': -1.0, 'done': False, 'info': {}}

On the second _on_step:

{'self': <stable_baselines.deepq.dqn.DQN object at 0x152c6ead0>, 'total_timesteps': 100, 'callback': <__main__.TestCallback object at 0x152c6eb50>, 'log_interval': 100, 'tb_log_name': 'DQN', 'reset_num_timesteps': True, 'replay_wrapper': None, 'new_tb_log': True, 'writer': None, 'exploration_annealing_steps': 10, 'episode_rewards': [-1.0], 'episode_successes': [], 'reset': False, 'obs': array([-0.40519748, -0.00188268]), '_': 1, 'kwargs': {}, 'update_eps': 0.902, 'update_param_noise_threshold': 0.0, 'action': 1, 'env_action': 1, 'new_obs': array([-0.40794961, -0.00275213]), 'rew': -1.0, 'done': False, 'info': {}, 'obs_': array([-0.4033148,  0.       ]), 'new_obs_': array([-0.40519748, -0.00188268]), 'reward_': -1.0, 'can_sample': False, 'mean_100ep_reward': -inf, 'num_episodes': 1}

So, depending on how this is implemented, the documentation should note whether each variable is available on every step or only from the second step onward.

Miffyli commented 4 years ago

So, depending on how this is implemented, the documentation should note whether each variable is available on every step or only from the second step onward.

I agree this should be done at the very least, so people know the information only becomes available on the second step, and that it is just the first step(s) that can break callback code.
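
Until that is documented, a defensive pattern in user callbacks is to treat these keys as optional; a minimal sketch, assuming the SB2 callback API used in the examples above:

from stable_baselines.common.callbacks import BaseCallback

class SafeDoneCallback(BaseCallback):
    def _on_step(self):
        # done may be missing on the very first step(s), so guard the lookup
        done = self.locals.get('done', False)
        if done and self.verbose:
            print("episode finished at timestep", self.num_timesteps)
        return True  # returning False would stop training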

m-rph commented 4 years ago

@araffin, how would you like this to be dealt with?

araffin commented 4 years ago

I would prefer an explicit call to .update_locals(locals()) rather than a lone call to locals(), if that is the question.

And the documentation should be updated in any case.
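
A minimal sketch of what such an explicit update could look like (update_locals is an assumed method name here; it was not part of the BaseCallback API at the time of this thread):

class SketchCallback:
    def __init__(self):
        self.locals = {}

    def update_locals(self, locals_):
        # Explicitly copy the loop's current locals, instead of relying on
        # CPython refreshing the frame-locals dict as a side effect of locals()
        self.locals.update(locals_)

    def on_step(self):
        # The callback can now rely on 'done' being present on every step
        return not self.locals['done']  # request a stop once an episode ends

callback = SketchCallback()
for done in [False, False, True]:
    callback.update_locals(locals())  # explicit, visible update each step
    if callback.on_step() is False:
        break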

m-rph commented 4 years ago

Okay, I will get on with the PR then.