Farama-Foundation / Gymnasium

An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)
https://gymnasium.farama.org
MIT License

[Question] Manually reset vector environment #831

Open zzhixin opened 9 months ago

zzhixin commented 9 months ago

Question

As far as I know, the gym vector environment auto-resets a sub-environment when that environment is done. I wonder if there is a way to reset it manually, because I want to exploit the vecenv feature when implementing the vanilla policy gradient algorithm, where each update's data should be one or several complete episodes.

Kallinteris-Andreas commented 9 months ago

https://gymnasium.farama.org/api/vector/#gymnasium.vector.VectorEnv.reset

zzhixin commented 9 months ago

@Kallinteris-Andreas Thanks for the reply, but that is not quite what I need. I am sorry if I didn't explain it clearly. My episode length is not constant. What I want is for a terminated (or truncated) sub-env not to auto-reset while the others continue until all sub-envs are done, so I can collect complete episodic data. Going further, say the batch size is around 256, which means I want at least 256 steps of data. I would like the vecenv, say with 8 parallel environments, to collect the smallest n complete episodes whose total step count is greater than 256. How can I achieve this? Does the current version of gym support this?
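
For concreteness, the behaviour I describe could be approximated on the caller's side with the current autoreset vector env, by assembling complete episodes manually until the step budget is reached. A rough sketch (random actions stand in for a policy; the exact autoreset bookkeeping for final observations differs between versions and is glossed over here):

    import gymnasium as gym
    import numpy as np

    num_envs, min_batch = 8, 256
    envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * num_envs)

    obs, _ = envs.reset(seed=0)
    episodes = []                                    # finished episodes
    current = [[] for _ in range(num_envs)]          # in-progress episode per sub-env

    while sum(len(ep) for ep in episodes) < min_batch:
        actions = envs.action_space.sample()         # stand-in for the policy
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)
        for i in range(num_envs):
            current[i].append((obs[i], actions[i], rewards[i]))
            if terminations[i] or truncations[i]:    # sub-env i just ended (and auto-reset)
                episodes.append(current[i])
                current[i] = []
        obs = next_obs

    # `episodes` now holds only complete episodes totalling at least `min_batch` steps;
    # whatever is left in `current` is discarded.
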

Kallinteris-Andreas commented 9 months ago

The vector env does not support barrier synchronization functionality (nor do I believe it should)

You can simply run multiple environments serially instead.

pseudo-rnd-thoughts commented 9 months ago

This is a feature we are considering adding but haven't implemented yet; the developers of PyTorch RL have also requested it. The problem is what data the sub-environments should return after termination / truncation. My initial thought is to keep returning termination = True for subsequent timesteps, so that you can check np.all(terminated | truncated) to see whether all environments have ended. But I fear this would cause issues for wrappers such as RecordEpisodeStatistics.
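
Until something like that lands, the same check can be approximated on the caller's side by latching the flags manually. A rough sketch (not current Gymnasium behaviour; random actions as a placeholder, and final-observation handling under autoreset is ignored):

    import gymnasium as gym
    import numpy as np

    envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 4)
    obs, _ = envs.reset(seed=0)
    finished = np.zeros(envs.num_envs, dtype=bool)   # latched terminated|truncated per sub-env

    while not finished.all():                        # the np.all(...) check described above
        actions = envs.action_space.sample()         # placeholder policy
        obs, rewards, terminated, truncated, _ = envs.step(actions)
        live = ~finished                             # sub-envs still in their first episode
        # ...store obs[live], rewards[live], etc. only for live sub-envs...
        finished |= terminated | truncated           # latch, mimicking the proposed semantics
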

jamartinh commented 8 months ago

Hi @zzhixin, @Kallinteris-Andreas, there are some vector env implementations, such as the Tianshou one, that add an id parameter to every env function (reset, step, etc.) indicating which of the environments the call should apply to.

@pseudo-rnd-thoughts perhaps a feature for Gymnasium?

See the base implementation: https://github.com/thu-ml/tianshou/blob/master/tianshou/env/venvs.py

https://github.com/thu-ml/tianshou/blob/294145aa3d0aca56be86c1ed3677ee14e23fda4e/tianshou/env/venvs.py#L235-L245

So in every call to reset or step you can specify which environments to execute.

You can then track which envs have returned terminated or truncated, stop calling them, and do something like the following (or better, if you find a way, but please update me on this because I need it too ;D):

    env = vec_env
    n_envs = env.env_num
    s, info = env.reset()
    dones = np.zeros(n_envs, dtype=bool)
    ids = np.arange(n_envs)                      # sub-envs that are still running

    for step in range(10000):
        a = get_action(s)                        # your policy
        # step only the sub-envs that have not finished yet
        s, _, _term, _trunc, _ = env.step(a, id=ids)
        _done = np.logical_or(_term, _trunc)     # finished either by termination or truncation
        dones[ids] = _done                       # mark finished sub-envs
        s = np.compress(~_done, s, axis=0)       # keep states of live sub-envs only
        ids = np.flatnonzero(~dones)             # ids of sub-envs to step next

        if np.all(dones):
            break

zzhixin commented 8 months ago

@pseudo-rnd-thoughts Hi, I have implemented a workaround. In short, if you have 16 vector envs and know the average episode length, say 20, then:

  1. Choose a suitably large batch size N such that N > 16*20*H, where H > 5. That means around H episodes are collected from each env every iteration.
  2. Collect one update's worth of data.
  3. Truncate the ongoing (incomplete) episode data of every env.
  4. envs.reset(). Go back to step 2.

The reason for H > 5 is that if H is large enough, it is no big loss to dump the ongoing data.

zzhixin commented 8 months ago

The implementation is not complicated and fairly decoupled. Say you have collected a batch of data with shape (num_steps, num_envs, single_data_len):

obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)

# Collect data using vecenv...

Then you can use the following code to truncate the ongoing data (and also flatten it):

if args.complete_eps:
    # index of the last step at which each env finished an episode
    last_done_t = [args.num_steps for _ in range(args.num_envs)]
    for e in range(args.num_envs):
        for t in reversed(range(args.num_steps)):
            if dones[t, e] == 1:
                break
        else:
            raise AssertionError("Env has no finished episode; maybe args.num_steps is too small")
        last_done_t[e] = t

    def truncated_stack(data, end):
        # concatenate, for each env, only the steps up to and including its last done step
        data_env = []
        for e in range(args.num_envs):
            data_env.append(data[:end[e] + 1, e])
        return torch.cat(data_env)

    # flatten the batch, dropping incomplete trailing episodes
    b_obs = truncated_stack(obs, last_done_t)
    b_logprobs = truncated_stack(logprobs, last_done_t)
    b_actions = truncated_stack(actions, last_done_t)
    b_dones = truncated_stack(dones, last_done_t)
    b_advantages = truncated_stack(advantages, last_done_t)
    b_returns = truncated_stack(returns, last_done_t)
    b_values = truncated_stack(values, last_done_t)

    real_batch_size = b_obs.shape[0]

RogerJL commented 8 months ago

I do something like this

    for p_, info, trajectories in Game.play_tournament(_a_gym_env, agents, record_trajectories=True, games_to_play=BATCH_GAMES, seed=seed):

where

    play_tournament(...)

        p_observation, _info = a_gym_env.reset(seed=seed)
        # - - -
        while remaining_games > 0:
            p_agent_output = p_action_logs
            p_action = jnp.array(list(map(os_action_v2, p_agent_output)))  # TODO: jnp.lax.map

            # Observe agent_output
            p_observation, p_reward, p_terminal, p_truncated, p_info = a_gym_env.step(p_action)

            if record_trajectories:
                for p_, player_ in enumerate(p_current_player):
                    agent_output = p_agent_output[p_]
                    agent_output = os_reward_replace_v2(agent_output, p_reward, p_)
                    trajectories[p_][player_].append(agent_output)
            p_done = jnp.logical_or(p_terminal, p_truncated)  # BUG in gymnasium => double reset
            del p_agent_output, p_reward, p_terminal, p_truncated  # only p_observation and p_info lives on

            p_done = jnp.logical_and(p_done, a_gym_env.autoreset_envs)
            for p_, done_ in zip(itertools.count(), jnp.logical_and(p_done, jnp.logical_not(p_done_prev))):
                if done_ and remaining_games > 0:
                    remaining_games -= 1
                    yield p_, Game.rebuild_info(p_info, p_), trajectories[p_] if record_trajectories else None
                    trajectories[p_] = tuple(list() for _ in range(GameState.MAX_PLAYERS))
            p_done_prev = p_done

im-Kitsch commented 6 months ago

I think a naive solution is to treat the problem as a queue and use a process pool. For example, if you would like to collect 100 trajectories, you could define a collect_trajectory() function and run it 100 times in parallel, with max_workers set to a limited batch size.
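
A minimal sketch of that idea using the standard library's ProcessPoolExecutor (the environment and the random stand-in policy are placeholders):

    from concurrent.futures import ProcessPoolExecutor

    import gymnasium as gym


    def collect_trajectory(seed):
        """Roll out one complete episode in its own worker (random policy as a stand-in)."""
        env = gym.make("CartPole-v1")
        obs, _ = env.reset(seed=seed)
        trajectory, done = [], False
        while not done:
            action = env.action_space.sample()
            next_obs, reward, terminated, truncated, _ = env.step(action)
            trajectory.append((obs, action, reward))
            obs = next_obs
            done = terminated or truncated
        env.close()
        return trajectory


    if __name__ == "__main__":
        # 100 complete trajectories, with at most 8 collected at any one time
        with ProcessPoolExecutor(max_workers=8) as pool:
            trajectories = list(pool.map(collect_trajectory, range(100)))
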

Another idea is EnvPool's asynchronous mode: https://github.com/sail-sg/envpool?tab=readme-ov-file#asynchronous-api . You could run something like this:

# until n_trajectories or n_transitions > threshold:
..., env_ids = env.recv()
not_done_idx = select_not_done_idx(env_ids)
act = actor(obs, not_done_idx)  # only compute actions for environments that are not terminated
env.step(act, not_done_idx)

Eventually no environment will be interacted with any more, and you can cut off the collection.
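
For reference, a more concrete sketch of the EnvPool asynchronous loop, based on its README example (the environment, random actions and fixed iteration count are placeholders; skipping finished env ids would happen where the actions are computed and sent):

    import envpool
    import numpy as np

    # 8 sub-envs; each recv() returns a batch of 4 that are ready to be stepped
    env = envpool.make("Pong-v5", env_type="gymnasium", num_envs=8, batch_size=4)
    env.async_reset()

    for _ in range(1000):                                   # stand-in for "until enough data"
        obs, rew, terminated, truncated, info = env.recv()
        env_id = info["env_id"]
        actions = np.random.randint(6, size=len(env_id))    # stand-in policy
        env.send(actions, env_id)                           # only the listed env ids get stepped
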