Arione94 opened 3 years ago
Hey. I have no experience with this plugin, but I assume this type of code should work. Note that we do not offer tech support for custom things like this, only bug reports for SB2. I also recommend checking out stable-baselines3, as it is more actively maintained.
Hey, thank you for your answer. I would just like to know whether the HPARAMS TensorBoard log is compatible with SB2, or whether there is something already implemented to facilitate it.
I have done it with a callback (I used SB3, but I think it should also work with previous versions).
import os

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp
from stable_baselines3.common.callbacks import BaseCallback


class HyperParamsCallback(BaseCallback):
    def __init__(self, logdir: str):
        super().__init__()
        self._done = False
        self.logdir = logdir

    def _on_step(self) -> bool:
        """Called on each step, but only logs on the first one."""
        # if the hyperparameters have already been logged, do nothing
        if not self._done:
            hparams = {
                "algorithm": self.model.__class__.__name__,
                "learning rate": self.model.learning_rate,
                "gamma": self.model.gamma,
            }
            # not every algorithm has a batch_size attribute (e.g. A2C does not)
            if hasattr(self.model, "batch_size"):
                hparams["batch_size"] = self.model.batch_size
            # save the hyperparameters in the same folder as the tensorboard
            # logs of the first run (SB3 appends "_0" to the run name)
            logs_path = os.path.join(self.logdir, self.model.__class__.__name__ + "_0")
            with tf.summary.create_file_writer(logs_path).as_default():
                hp.hparams(hparams)
            self._done = True
        # returning True tells SB3 to continue training
        return True
The keys of the hparams dictionary can also be hp.HParam objects instead of plain strings.
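For example (a minimal sketch, assuming the tensorboard package is installed; the parameter names and domains below are made up for illustration):

```python
from tensorboard.plugins.hparams import api as hp

# hypothetical HParam definitions with explicit domains, so the
# HPARAMS tab can offer per-parameter filtering widgets
HP_LR = hp.HParam("learning rate", hp.RealInterval(1e-5, 1e-2))
HP_GAMMA = hp.HParam("gamma", hp.RealInterval(0.9, 1.0))

# same dictionary shape as in the callback, keyed by HParam objects
hparams = {
    HP_LR: 7e-4,
    HP_GAMMA: 0.99,
}
```

A dictionary keyed this way can be passed to hp.hparams(...) in the callback exactly like the string-keyed version.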
It can then be used like any other callback:
from stable_baselines3 import A2C
logdir = "/tmp/logs/"
model = A2C('MlpPolicy', 'CartPole-v1', tensorboard_log=logdir)
model.learn(10000, tb_log_name='A2C', callback=HyperParamsCallback(logdir))
The hyperparameters will then be displayed in the HPARAMS tab of TensorBoard.
Hi guys,
let me first say that I am quite new to TensorFlow and particularly TensorBoard. I just started watching some videos and tutorials, and I found the possibility of tuning hyperparameters using TensorBoard (in the HPARAMS tab) pretty interesting. I would like to apply this methodology with DDQN, but I am not able to do it.
Starting from this piece of code, is there someone able to help me complete it so it logs HPARAM information for the HPARAMS tab of TensorBoard?
Thank you very much for your support!