DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] How can I log Q-values of DQN in TensorBoard using a custom callback from Stable Baselines3? #1948

Closed: gglsmm closed this issue 1 week ago

gglsmm commented 2 weeks ago

❓ Question

Hello,

I am trying to log Q-values using a custom callback, but I am new to this field and not sure whether the code below is the correct way to do it.

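```python
import torch as th
from stable_baselines3.common.callbacks import BaseCallback


class CustomLoggingCallback(BaseCallback):
    def __init__(self, verbose=1):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # Log Q-values of the most recent observations returned by env.step().
        # Note: replay_buffer.observations[-1] is the last slot of the circular
        # buffer (all zeros until the buffer fills up), not the latest transition.
        obs, _ = self.model.policy.obs_to_tensor(self.locals["new_obs"])
        with th.no_grad():
            q_values = self.model.q_net(obs)  # shape: (n_envs, n_actions)
        avg_q_values = q_values.mean().item()
        self.logger.record("q_values/mean", avg_q_values)

        return True
```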


araffin commented 2 weeks ago

Probably a duplicate of https://github.com/DLR-RM/stable-baselines3/issues/568 and https://github.com/DLR-RM/stable-baselines3/issues/308

gglsmm commented 2 weeks ago

Thank you for your response. In #568, I assume the code is meant for after training.

What I want is to log Q-values during the training process. For example, we can watch the loss or the mean episode reward on TensorBoard during training, and I would like to see the Q-values there too, if possible. See the attached plot; it is for the Pong game. I am just not sure whether this code is right.

[Attached image: Q-values]

araffin commented 1 week ago

> I assume it is for after training.

Why?

> See attached; it is for the Pong game. I am just not sure if this code is right.

It looks fine, although you should use the code example I linked, and you might want to call `logger.dump()` if the values are not logged often enough for you.

gglsmm commented 1 week ago

> > I assume it is for after training.
>
> Why?

I collect Q-values during training, so my callback is passed to `model.learn()`. For the code in #568, I assumed it collects Q-values after loading a trained model; if not, that is my misunderstanding.

Thank you for your feedback.