vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

PPO + JAX + EnvPool + MuJoCo #217

Open vwxyzjn opened 2 years ago

vwxyzjn commented 2 years ago

Description

Types of changes

Checklist:

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

vwxyzjn commented 2 years ago

It seems that there isn't much benefit for PPO here: the SPS metric is not much better, as shown below.

[image: SPS comparison]

Note: there is probably a bug... that's why the sample efficiency suffers.

Maybe I was implementing PPO using the incorrect paradigm with JAX. Any thoughts on this @joaogui1 and @ikostrikov? Thanks!

ikostrikov commented 2 years ago

I'm not sure if

obs = obs.at[step].set(x)

is indeed in-place inside of jit. I think in this specific case it still creates a new array. I think it's truly in-place only for specific use cases. For example, when memory is donated (on TPU and GPU only). Could you double check that?
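
For reference, a minimal sketch of the donation case (not code from this PR; donate_argnums is the standard jax.jit mechanism, and donation is honored on GPU/TPU but ignored with a warning on CPU):

import jax
import jax.numpy as jnp

def write_row(obs, x, step):
    # Functional update: returns a new array unless XLA can reuse the input buffer.
    return obs.at[step].set(x)

copy_update = jax.jit(write_row)
# donate_argnums=(0,) tells XLA it may reuse obs's buffer for the output.
donated_update = jax.jit(write_row, donate_argnums=(0,))

obs = jnp.zeros((8, 4))
x = jnp.ones(4)
out = copy_update(obs, x, 0)     # obs is untouched; out lives in a new buffer
out = donated_update(obs, x, 0)  # obs's buffer may be reused; obs must not be read afterwards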

vwxyzjn commented 2 years ago

The latest commit fixes two stupid bugs; we can now match the exact same performance :)

[image: learning curves showing matching performance]

vwxyzjn commented 2 years ago

Using the official optimizer scheduler API results in a 70% speedup (commit) compared to the manual override suggested here 🚀

https://github.com/vwxyzjn/cleanrl/blob/a0c56d3e229f1a763dcfae3e7b6ae8d9b08ed9bf/cleanrl/ppo_continuous_action_envpool_jax.py#L204-L212

[image: SPS comparison showing the speedup]
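
For context, a hedged sketch of what the built-in-schedule approach looks like in optax (the hyperparameter values below are illustrative, not necessarily the ones in the linked file): the schedule is handed to the optimizer once, so the learning rate no longer has to be overwritten in the optimizer state every update.

import optax

# Illustrative values; the script derives transition_steps from its own arguments.
num_updates, update_epochs, num_minibatches = 1000, 4, 32
schedule = optax.linear_schedule(
    init_value=3e-4,
    end_value=0.0,
    transition_steps=num_updates * update_epochs * num_minibatches,
)
tx = optax.chain(
    optax.clip_by_global_norm(0.5),  # global-norm gradient clipping, as in most PPO setups
    optax.adam(learning_rate=schedule, eps=1e-5),
)
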
vwxyzjn commented 2 years ago

> I'm not sure if
>
> obs = obs.at[step].set(x)
>
> is indeed in-place inside of jit. I think in this specific case it still creates a new array. I think it's truly in-place only for specific use cases. For example, when memory is donated (on TPU and GPU only). Could you double check that?

Maybe the documentation meant that if you had created the array inside the jitted function, the operation would be in place? I tested:

print("id(obs) before", id(obs))
obs, dones, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value(
    next_obs, next_done, obs, dones, actions, logprobs, values, step, agent_params, key
)
print("id(obs) after", id(obs))

which gives

id(obs) before 140230683526704
id(obs) after 140230683590064
ikostrikov commented 2 years ago

@vwxyzjn yes, I think it's either for arrays created inside of jit or donated arguments.

nico-bohlinger commented 2 years ago

Jitting the epoch loop inside update_ppo() results in extremely long start-up times for high epoch values and doesn't provide any speedup once it's finally running. Moving the epoch loop into the main function would fix that, like:

for _ in range(args.update_epochs):
    agent_state, loss, pg_loss, v_loss, approx_kl, key = update_ppo(agent_state, storage, key)
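
To make the compile-time effect concrete, here is a toy sketch (the quadratic loss and names like one_epoch are made up for illustration, not the PR's update_ppo): jitting a function whose Python loop covers all epochs unrolls that loop at trace time, so compilation grows with update_epochs, while jitting a single epoch and looping in Python compiles the epoch once and reuses it.

import jax
import jax.numpy as jnp

update_epochs = 64
params = jnp.ones(3)
data = jnp.arange(3.0)

def one_epoch(params, data):
    # One gradient step on a toy quadratic loss, standing in for one PPO epoch.
    grads = jax.grad(lambda p: jnp.sum((p - data) ** 2))(params)
    return params - 1e-2 * grads

@jax.jit
def update_all_epochs(params, data):
    # The Python loop is unrolled during tracing: compile time scales with update_epochs.
    for _ in range(update_epochs):
        params = one_epoch(params, data)
    return params

update_one_epoch = jax.jit(one_epoch)  # compiled once, reused every epoch

params_a = update_all_epochs(params, data)
params_b = params
for _ in range(update_epochs):
    params_b = update_one_epoch(params_b, data)
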
51616 commented 1 year ago

I think it's worth changing to lax.scan and fori_loop. Removing the for loop within rollout increases the speed quite a bit and significantly reduces the compilation time. I can make a pull request for this (and for compute_gae and update_ppo as well). I compared the original rollout and the lax.scan implementation and got the following results:

# Original for loop
Total data collection time: 135.69225978851318 seconds
Total data collection time without compilation: 98.75351285934448 seconds
Approx. compilation time: 36.93875765800476 seconds
# with lax.scan
Total data collection time: 60.91851544380188 seconds
Total data collection time without compilation: 60.029022455215454 seconds
Approx. compilation time: 0.8895087242126465 seconds

The command used is: python cleanrl/ppo_atari_envpool_xla_jax.py --env-id Breakout-v5 --total-timesteps 500000 --num-envs 32

Note: The training code was removed because the collection time correlates with avg_episodic_length, which depends on random exploration and training dynamics. Removing the training part ensures that the numbers in this test relate only to the rollout function.
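
A minimal toy (separate from the benchmark above) showing where the savings come from: under jit, a Python loop over the steps is unrolled into one copy of the step per iteration during tracing, whereas lax.scan traces the step body once, so the program to compile stays small while the results are identical.

import jax
import jax.numpy as jnp

num_steps = 128

def step(carry, _):
    # Stand-in for one environment step; returns (new carry, per-step output).
    carry = carry * 0.99 + 1.0
    return carry, carry

@jax.jit
def rollout_unrolled(carry):
    outs = []
    for _ in range(num_steps):  # unrolled: num_steps copies of step in the traced program
        carry, out = step(carry, None)
        outs.append(out)
    return carry, jnp.stack(outs)

@jax.jit
def rollout_scan(carry):
    return jax.lax.scan(step, carry, None, length=num_steps)  # step traced once

c1, o1 = rollout_unrolled(jnp.zeros(()))
c2, o2 = rollout_scan(jnp.zeros(()))  # same values, much less to compile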

vwxyzjn commented 1 year ago

@51616 thanks for raising this issue. Could you share the snippet that derived these numbers?

Does lax.scan reduce the rollout time after compilation is finished? nvm I misread something. It’s interesting the rollout time after compilation is much faster, and this would be a good reason to consider using scan. Would you mind preparing the PR?

51616 commented 1 year ago

@vwxyzjn Here's the code:

    def step_once(carry, step, env_step_fn):
        (agent_state, episode_stats, next_obs, next_done, storage, key, handle) = carry
        storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)
        episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action)
        storage = storage.replace(rewards=storage.rewards.at[step].set(reward))
        return ((agent_state, episode_stats, next_obs, next_done, storage, key, handle), None)

    def rollout(agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step,
                step_once_fn, max_steps):

        (agent_state, episode_stats, next_obs, next_done, storage, key, handle), _ = jax.lax.scan(
            step_once_fn,
            (agent_state, episode_stats, next_obs, next_done, storage, key, handle), (), max_steps)

        global_step += max_steps * args.num_envs
        return agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step

    rollout_fn = partial(rollout,
                         step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed),
                         max_steps=args.num_steps)

    for update in range(1, args.num_updates + 1):
        update_time_start = time.time()
        agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step = rollout_fn(
            agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step
        )
        if update == 1:
            start_time_wo_compilation = time.time()
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
        print("SPS_update:", int(args.num_envs * args.num_steps / (time.time() - update_time_start)))
        writer.add_scalar(
            "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step
        )
    print("Total data collection time:", time.time() - start_time, "seconds")
    print("Total data collection time without compilation:", time.time() - start_time_wo_compilation, "seconds")
    print("Approx. compilation time:", start_time_wo_compilation - start_time, "seconds")
    envs.close()
    writer.close()

I can make a PR for this. I also think we should use the output of lax.scan instead of replacing values in place. It might look something like this:

    def step_once(carry, step, env_step_fn):
        (agent_state, episode_stats, obs, done, key, handle) = carry
        action, logprob, value, key = get_action_and_value(agent_state, obs, key)

        episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action)

        storage = Storage(
            obs=obs,
            actions=action,
            logprobs=logprob,
            dones=done,
            values=value,
            rewards=reward,
            returns=jnp.zeros_like(reward),
            advantages=jnp.zeros_like(reward),
        )

        return ((agent_state, episode_stats, next_obs, next_done, key, handle), storage)

    def rollout(agent_state, episode_stats, next_obs, next_done, key, handle,
                step_once_fn, max_steps):

        (agent_state, episode_stats, next_obs, next_done, key, handle), storage = jax.lax.scan(
            step_once_fn,
            (agent_state, episode_stats, next_obs, next_done, key, handle), (), max_steps)

        return agent_state, episode_stats, next_obs, next_done, key, handle, storage

    rollout_fn = partial(rollout,
                         step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed),
                         max_steps=args.num_steps)

    for update in range(1, args.num_updates + 1):
        update_time_start = time.time()
        agent_state, episode_stats, next_obs, next_done, key, handle, storage = rollout_fn(
            agent_state, episode_stats, next_obs, next_done, key, handle
        )
        if update == 1:
            start_time_wo_compilation = time.time()
        global_step += args.num_steps * args.num_envs
        ...

The code is a bit cleaner and uses the output from lax.scan directly.
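
One additional note on this variant, as a hedged sketch (the flatten_storage helper is illustrative, not part of the snippet above, and assumes Storage is a registered pytree such as a flax.struct dataclass): lax.scan stacks its per-step outputs along a new leading axis, so every field of the returned storage has shape (num_steps, num_envs, ...), and flattening for minibatch sampling becomes a single tree_map instead of indexed writes.

import jax

def flatten_storage(storage):
    # Merge the (num_steps, num_envs) leading axes of every field into one batch axis.
    return jax.tree_util.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), storage)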

vwxyzjn commented 1 year ago
[image]

pseudo-rnd-thoughts commented 1 year ago

@vwxyzjn Was there any reason why this wasn't merged in the end?

vwxyzjn commented 1 year ago

Nothing really. If you'd like, feel free to take on the PR :)