vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Truncation not handled correctly when `optimize_memory_usage=True` #460

Open samlobel opened 2 months ago

samlobel commented 2 months ago

Problem Description

In dqn_atari.py, for example, the replay buffer is instantiated with the optimize_memory_usage=True flag. With this flag the buffer keeps only a single stored array of observations and reconstructs next_obs as observations[i+1] at sampling time. cleanrl has its own logic for truncated episodes (if trunc: real_next_obs[idx] = infos["final_observation"][idx]) and passes real_next_obs to rb.add(), but with optimize_memory_usage=True that correction never shows up in the stored/sampled data: the slot it is written to gets reused for the next transition's observation.
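For context, here is a paraphrased sketch (from memory; simplified, not the actual source) of what the stable_baselines3 ReplayBuffer does when optimize_memory_usage=True:

# Paraphrased sketch of SB3's memory-optimized buffer, not the real implementation.
class _MemoryOptimizedBufferSketch:
    def add(self, obs, next_obs, action, reward, done, infos):
        self.observations[self.pos] = obs
        # next_obs is written into the slot that the *next* add() reuses for its obs,
        # so a truncation-corrected real_next_obs only survives until the next step
        self.observations[(self.pos + 1) % self.buffer_size] = next_obs
        self.pos = (self.pos + 1) % self.buffer_size

    def _get_samples(self, batch_inds):
        # next observations are reconstructed from the shared array at sampling time
        return self.observations[(batch_inds + 1) % self.buffer_size]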

Current Behavior

When an episode is truncated, data.next_observations[i] is not the final observation of that episode; it is the first observation of the freshly reset environment.
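To make that concrete, here is a tiny self-contained toy (my own illustration, not SB3 code) of the shared-array layout losing the truncation fix:

import numpy as np

# Toy version of the shared-array layout: next_obs is read back as observations[i + 1].
size = 8
observations = np.zeros(size)

def add(pos, obs, next_obs):
    observations[pos] = obs
    observations[(pos + 1) % size] = next_obs  # shared slot, no separate next_observations array

# step 3 is truncated: cleanrl would pass the final observation (say 99.0) as next_obs
add(3, obs=0.3, next_obs=99.0)
# step 4 starts a new episode: its obs is the reset observation (say 0.0)
add(4, obs=0.0, next_obs=0.1)

# sampling index 3 later reconstructs next_obs as observations[4],
# which is now the reset observation 0.0, not the stored final observation 99.0
print(observations[(3 + 1) % size])  # 0.0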

Expected Behavior

data.next_observations[i] should be the final observation of the truncated episode, i.e. the one cleanrl copies out of infos["final_observation"].

Possible Solution

I'm guessing there's a way to make this work, but for now the easiest thing to do is set optimize_memory_usage to False.
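For reference, that workaround is just flipping the flag when constructing the buffer (same call as in the reproduction below, at the cost of storing next observations in a separate array):

rb = ReplayBuffer(
    1000,
    envs.single_observation_space,
    envs.single_action_space,
    "cpu",
    optimize_memory_usage=False,  # extra observation memory, but truncation is handled correctly
    handle_timeout_termination=False,
)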

Steps to Reproduce

Here's a minimal code example; the important parts are cribbed directly from dqn_atari.py. Switching to optimize_memory_usage=False makes the assertion pass.

import gymnasium as gym

from stable_baselines3.common.buffers import ReplayBuffer
import stable_baselines3 as sb3
import numpy as np

def make_env(env_id, seed, idx, capture_video, run_name):    
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk

envs = gym.vector.SyncVectorEnv(
    [make_env("MountainCar-v0", i, i, False, "testing") for i in [0]]
)

obs, _ = envs.reset(seed=0)

rb = ReplayBuffer(
    1000,
    envs.single_observation_space,
    envs.single_action_space,
    "cpu",
    optimize_memory_usage=True,
    # optimize_memory_usage=False,
    handle_timeout_termination=False,
)

seen_obs_and_next = set()
for i in range(1000):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)
    # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]

    rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
    for o, next_o in zip(obs, real_next_obs):  # because vectorized env
        seen_obs_and_next.add((tuple(o.tolist()), tuple(next_o.tolist())))

    obs = next_obs  # advance to the next observation, as in dqn_atari.py

data = rb.sample(10000)
for i in range(10000):
    o = data.observations[i]
    no = data.next_observations[i]
    assert (tuple(o.tolist()), tuple(no.tolist())) in seen_obs_and_next