hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Tensorboard HPARAMS with DDQN #question #1128

Open Arione94 opened 3 years ago

Arione94 commented 3 years ago

Hi guys,

let me first say that I am quite new to TensorFlow and particularly TensorBoard. I just started watching some videos and tutorials, and I found the ability to tune hyperparameters through TensorBoard's HPARAMS tab quite appealing. I would like to apply this methodology with DDQN, but I have not been able to do it.

Starting from the piece of code below, could someone help me complete it so that it logs HPARAM information for use in the HPARAMS tab of TensorBoard?

import gym

from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines import DQN

from tensorboard.plugins.hparams import api as hp

# search spaces for the two hyperparameters
HP_LR = hp.HParam('learning_rate', hp.Discrete([1e-1, 1e-2, 1e-3, 1e-4, 1e-5]))
HP_GAMMA = hp.HParam('gamma', hp.Discrete([0.5, 0.7, 0.9]))

env = gym.make('CartPole-v1')

session_num = 0

# grid search over all combinations
for lr in HP_LR.domain.values:
    for gamma in HP_GAMMA.domain.values:

        # hyperparameter values for this trial (built, but not logged anywhere yet)
        hparams = {
            HP_LR: lr,
            HP_GAMMA: gamma,
        }
        run_name = "run-%d" % session_num

        session_num += 1
        model = DQN(MlpPolicy, env, gamma=gamma, learning_rate=lr,
                    tensorboard_log='tensorboard/' + run_name + '/')
        model.learn(total_timesteps=25000)
        del model  # free the model before the next trial
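
From the tutorials, my understanding is that each trial is then wrapped in a summary writer that records the hyperparameter values plus at least one result metric, roughly like the snippet below. This is just my sketch of the plugin's documented pattern: the metric is a placeholder, and I am not sure this writer API works with the TF1 version that SB2 requires.

import tensorflow as tf

# inside the inner loop, after model.learn(...):
with tf.summary.create_file_writer('tensorboard/' + run_name + '/').as_default():
    hp.hparams(hparams)  # record this trial's hyperparameter values
    # placeholder: some evaluation result (e.g. mean episode reward) should go here
    tf.summary.scalar('mean_reward', 0.0, step=1)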

Thank you very much for your support!

Miffyli commented 3 years ago

Hey. I have no experience with this plugin, but I assume this type of code should work. Note that we do not offer tech support for custom things like this, only bug reports for SB2. I also recommend checking out stable-baselines3, as it is more actively maintained.

Arione94 commented 3 years ago

Hey, thank you for your answer. I would just like to know whether HPARAM TensorBoard logging is compatible with SB2, or whether there is something already implemented to facilitate it.

Miffyli commented 3 years ago

Hmm, not in SB2/SB3 itself, but you could check the zoo (and zoo3) for hints towards this; likely not, though, as they use custom config files for hyperparameters.

timothe-chaumont commented 2 years ago

I have done it with a callback (I used SB3, but I think it should work with previous versions).

import os

import tensorflow as tf

from tensorboard.plugins.hparams import api as hp
from stable_baselines3.common.callbacks import BaseCallback

class HyperParamsCallback(BaseCallback):
    def __init__(self, logdir: str):
        super().__init__()
        self._done = False
        self.logdir = logdir

    def _on_step(self) -> bool:
        """Called on every step, but only logs the hyperparameters once, on the first step."""
        # if they have already been logged, skip straight to the return
        if not self._done:
            hparams = {
                "algorithm": self.model.__class__.__name__,
                "learning_rate": self.model.learning_rate,
                "gamma": self.model.gamma,
            }
            # not all algorithms have a batch_size (e.g. A2C does not)
            if hasattr(self.model, "batch_size"):
                hparams["batch_size"] = self.model.batch_size

            # write the hyperparameters into the same folder as the tensorboard logs
            # (SB3 appends a run index to tb_log_name, starting at 1 by default)
            logs_path = os.path.join(self.logdir, self.model.__class__.__name__ + "_1")
            with tf.summary.create_file_writer(logs_path).as_default():
                hp.hparams(hparams)
            self._done = True
        # returning True keeps training running
        return True

The keys of the hparams dictionary can also be hp.HParam objects instead of plain strings, as in the sketch below.
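
A rough sketch of that variant (the domains here are placeholders; declaring them tells the HPARAMS tab the range of possible values for each column):

from tensorboard.plugins.hparams import api as hp

# declare each hyperparameter with an explicit domain (placeholder ranges)
HP_LR = hp.HParam("learning_rate", hp.RealInterval(1e-5, 1e-1))
HP_GAMMA = hp.HParam("gamma", hp.Discrete([0.9, 0.99]))

# inside the callback, build the dict with HParam keys instead of strings
hparams = {
    HP_LR: 7e-4,     # e.g. self.model.learning_rate
    HP_GAMMA: 0.99,  # e.g. self.model.gamma
}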

It can then be used like any other callback:

from stable_baselines3 import A2C

logdir = "/tmp/logs/"

model = A2C('MlpPolicy', 'CartPole-v1', tensorboard_log=logdir)
model.learn(10000, tb_log_name='A2C', callback=HyperParamsCallback(logdir))

The hyperparameters will then be displayed in the HPARAMS tab of TensorBoard.
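
The HPARAMS dashboard is most useful when each run also logs at least one metric to compare the trials against. A minimal sketch, assuming SB3's default run numbering (the first run with tb_log_name='A2C' goes to A2C_1) and using evaluate_policy to produce the metric:

import os

import tensorflow as tf
from stable_baselines3.common.evaluation import evaluate_policy

# evaluate the trained model to get a comparison metric
mean_reward, _ = evaluate_policy(model, model.get_env(), n_eval_episodes=10)

# write the metric into the same run folder as the hyperparameters
run_dir = os.path.join(logdir, "A2C_1")
with tf.summary.create_file_writer(run_dir).as_default():
    tf.summary.scalar("mean_reward", mean_reward, step=model.num_timesteps)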