DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Logger information #1998

Closed: XiaobenLi00 closed this 1 month ago

XiaobenLi00 commented 2 months ago

🐛 Bug

When logging info to TensorBoard, self.logger.dump(step=self.num_timesteps) is called after the self.logger.record calls:

self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
if len(self.ep_success_buffer) > 0:
    self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
self.logger.dump(step=self.num_timesteps)

But in the wandb plots, the x-axis is just wandb's own step count instead of self.num_timesteps.


Do you have any suggestions?

BTW, it seems that this info is only logged once per rollout, so the interval between logged points is quite large. How can I get denser logging?
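
To illustrate what I mean by denser logging, I imagine something along these lines (just a rough sketch of a custom callback; the metric name and the interval are placeholders I made up), though I am not sure this is the intended way:

import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

class DenseLogCallback(BaseCallback):
    """Sketch: record and dump a metric every `log_every` environment steps
    instead of once per rollout (metric name and interval are placeholders)."""

    def __init__(self, log_every: int = 1000):
        super().__init__()
        self.log_every = log_every

    def _on_step(self) -> bool:
        if self.num_timesteps % self.log_every == 0:
            # `rewards` should be available in the callback locals during rollout collection
            self.logger.record("rollout/latest_reward_mean", float(np.mean(self.locals["rewards"])))
            # Flush all pending records, keyed by the true environment step count
            self.logger.dump(step=self.num_timesteps)
        return True

Passing an instance of this in the callback list to model.learn() would then produce one point every log_every steps.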

Looking forward to your suggestions!

Code example

""" =================================================
# Copyright (c) Facebook, Inc. and its affiliates
Authors  :: Cameron Berg (cameronberg@fb.com), Vikash Kumar (vikashplus@gmail.com), Vittorio Caggiano (caggiano@gmail.com)
================================================= """

"""
This is a job script for running SB3 on myosuite tasks.
"""
from typing import Any, Dict

import os
import json
import time as timer
from stable_baselines3 import PPO, SAC
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import VecNormalize, SubprocVecEnv
import torch
import gymnasium as gym
# import torch as th
import numpy as np
from omegaconf import OmegaConf
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import Video

os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
os.environ['MUJOCO_GL'] = 'osmesa'

import functools
from in_callbacks import InfoCallback, FallbackCheckpoint, SaveSuccesses, EvalCallback

IS_WnB_enabled = False
try:
    import wandb
    from wandb.integration.sb3 import WandbCallback
    IS_WnB_enabled = True
except ImportError as e:
    pass 

class VideoRecorderCallback(BaseCallback):
    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 5, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded
        :param render_freq: Render the agent's trajectory every eval_freq call of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        if self.n_calls % self._render_freq == 0:
            screens = []

            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                """
                Renders the environment in its current state, recording the screen in the captured `screens` list

                :param _locals: A dictionary containing all local variables of the callback's scope
                :param _globals: A dictionary containing all global variables of the callback's scope
                """
                # We expect `render()` to return a uint8 array with values in [0, 255] or a float array
                # with values in [0, 1], as described in
                # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
                screen = self._eval_env.render(mode="rgb_array")
                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                screens.append(screen.transpose(2, 0, 1))

            evaluate_policy(
                self.model,
                self._eval_env,
                callback=grab_screens,
                n_eval_episodes=self._n_eval_episodes,
                deterministic=self._deterministic,
            )
            self.logger.record(
                "trajectory/video",
                Video(torch.from_numpy(np.asarray([screens])), fps=40),
                exclude=("stdout", "log", "json", "csv"),
            )
        return True

def train_loop(job_data) -> None:

    config = {
            "policy_type": job_data.policy,
            "total_timesteps": job_data.total_timesteps,
            "env_name": job_data.env,
    }
    if IS_WnB_enabled:
        run = wandb.init(
            project=job_data.project,
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,  # auto-upload the videos of agents playing the game
            save_code=True,  # optional
        )

    logger = configure(f'results_{job_data.env}', ["stdout", "tensorboard", "log"])
    # Create the vectorized environment and normalize ob
    # env = make_vec_env(job_data.env, n_envs=job_data.n_env)
    env = make_vec_env(job_data.env, n_envs=job_data.n_env, vec_env_cls=SubprocVecEnv)
    env = VecNormalize(env, norm_obs=True, norm_reward=False, clip_obs=10.)

    eval_env = make_vec_env(job_data.env, n_envs=job_data.n_eval_env, vec_env_cls=SubprocVecEnv)
    # eval_env = make_vec_env(job_data.env, n_envs=job_data.n_eval_env)
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False, clip_obs=10.)

    algo = job_data.algorithm
    if algo == 'PPO':
        # Load activation function from config
        policy_kwargs = OmegaConf.to_container(job_data.policy_kwargs, resolve=True)

        model = PPO(job_data.policy, env, verbose=1,
                    # `run` only exists when wandb is available, so guard the tensorboard path
                    tensorboard_log=f"wandb/{run.id}" if IS_WnB_enabled else None,
                    learning_rate=job_data.learning_rate, 
                    batch_size=job_data.batch_size, 
                    policy_kwargs=policy_kwargs,
                    gamma=job_data.gamma, **job_data.alg_hyper_params)
    elif algo == 'SAC':
        model = SAC(job_data.policy, env, 
                    learning_rate=job_data.learning_rate, 
                    buffer_size=job_data.buffer_size, 
                    learning_starts=job_data.learning_starts, 
                    batch_size=job_data.batch_size, 
                    tau=job_data.tau, 
                    gamma=job_data.gamma, **job_data.alg_hyper_params)

    if job_data.job_name == "checkpoint.pt":
        foldername = os.path.join(os.path.dirname(os.path.realpath(__file__)), f"baseline_SB3/myoChal24/{job_data.env}")
        file_path = os.path.join(foldername, job_data.job_name)
        if os.path.isfile(file_path):
            print("Loading weights from checkpoint")
            model.policy.load_state_dict(torch.load(file_path))
        else:
            raise FileNotFoundError(f"No file found at the specified path: {file_path}. See https://github.com/MyoHub/myosuite/blob/dev/myosuite/agents/README.md to download one.")
    else:
        print("No checkpoint loaded, training starts.")

    if IS_WnB_enabled:
        callback = [WandbCallback(
                model_save_path=f"models/{run.id}",
                verbose=2,
            )]
    else:
        callback = []

    # callback += [EvalCallback(max(job_data.eval_freq // job_data.n_env, 1), eval_env)]
    callback += [InfoCallback()]
    callback += [FallbackCheckpoint(max(job_data.restore_checkpoint_freq // job_data.n_env, 1))]
    callback += [CheckpointCallback(save_freq=max(job_data.save_freq // job_data.n_env, 1),
                                    save_path='logs/', name_prefix='rl_models')]
    # callback += [VideoRecorderCallback(eval_env, render_freq=max(job_data.render_freq // job_data.n_env, 1))]

    model.set_logger(logger)
    model.learn(
        total_timesteps=config["total_timesteps"],
        callback=callback,
    )

    model.save(f"{job_data.env}_{algo}_model")
    env.save(f"{job_data.env}_{algo}_env")

    if IS_WnB_enabled:
        run.finish()

Relevant log output / Error message

No response

System Info

- OS: Linux-5.15.0-67-generic-x86_64-with-glibc2.10 # 74~20.04.1-Ubuntu SMP Wed Feb 22 14:52:34 UTC 2023
- Python: 3.8.19
- Stable-Baselines3: 2.3.2
- PyTorch: 2.4.0+cu121
- GPU Enabled: True
- Numpy: 1.24.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
- OpenAI Gym: 0.26.2

araffin commented 2 months ago

But in the wandb plots, the x-axis is just wandb's own step count instead of self.num_timesteps.

Please have a look at https://www.youtube.com/watch?v=ed1bqaZGOQw

You can change that in the "X-axis" selector: you need to select "global_step".
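
If you would rather set this programmatically than in the UI, something along these lines should also work (just a sketch, assuming sync_tensorboard=True as in your script so that the TensorBoard step is synced as global_step; I have not tested this exact snippet):

import wandb

run = wandb.init(project="my-project", sync_tensorboard=True)
# Plot every logged metric against the synced TensorBoard step by default
wandb.define_metric("*", step_metric="global_step")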