Open vwxyzjn opened 2 years ago
It seems that there isn't much benefit for PPO - the SPS metric is not a lot better, as shown below.
Note: there is probably a bug... that's why the sample efficiency suffers.
Maybe I was implementing PPO with an incorrect paradigm for JAX. Any thoughts on this @joaogui1 and @ikostrikov? Thanks!
I'm not sure if

```python
obs = obs.at[step].set(x)
```

is indeed in-place inside of `jit`. I think in this specific case it still creates a new array. I think it's truly in-place only for specific use cases, for example when memory is donated (on TPU and GPU only). Could you double-check that?
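As a minimal sketch of the semantics in question (`write_step` is a made-up toy function, not part of cleanrl): `.at[step].set(x)` is functionally pure from the caller's point of view, and XLA may only rewrite it into a true in-place update when it can prove the input buffer is free to reuse.

```python
import jax
import jax.numpy as jnp

@jax.jit
def write_step(buf, step, x):
    # Functionally pure: returns a new array. Under jit, XLA may lower this
    # to an in-place update only when it can prove the input buffer is free
    # to reuse (e.g. it was created inside the jitted function, or donated).
    return buf.at[step].set(x)

buf = jnp.zeros((4, 3))
out = write_step(buf, 1, jnp.ones(3))
print(buf[1].sum(), out[1].sum())  # the caller's `buf` is untouched
```

Whatever happens inside the compiled computation, the caller-visible contract is that `buf` is unchanged and `out` is a fresh value.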
The latest commit fixes two stupid bugs; we can now match the exact same performance :)
> I'm not sure if `obs = obs.at[step].set(x)` is indeed in-place inside of `jit`. I think in this specific case it still creates a new array. I think it's truly in-place only for specific use cases. For example, when memory is donated (on TPU and GPU only). Could you double check that?
Maybe the documentation meant that if you had created the array inside the jitted function the operation would be in place? I tested out

```python
print("id(obs) before", id(obs))
obs, dones, actions, logprobs, values, action, logprob, entropy, value, key = get_action_and_value(
    next_obs, next_done, obs, dones, actions, logprobs, values, step, agent_params, key
)
print("id(obs) after", id(obs))
```

which gives

```
id(obs) before 140230683526704
id(obs) after 140230683590064
```
@vwxyzjn yes, I think it's either for arrays created inside of jit or donated arguments.
Jitting the epochs in `update_ppo()` results in extremely high start-up times for high epoch values and doesn't provide any speedup once it's finally running. Moving the epoch loop into the main function would fix that, like:

```python
for _ in range(args.update_epochs):
    agent_state, loss, pg_loss, v_loss, approx_kl, key = update_ppo(agent_state, storage, key)
```
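The compile-time blow-up comes from loop unrolling: a Python loop traced inside `jit` emits one copy of the body per iteration into a single XLA graph. A toy comparison (where `epoch_body` is a stand-in for one PPO update epoch, not the real `update_ppo`):

```python
import jax
import jax.numpy as jnp

def epoch_body(x):
    return x * 0.99 + 1.0  # stand-in for one update epoch

@jax.jit
def all_epochs(x):
    # Unrolled under jit: 50 copies of the body land in one XLA graph,
    # so compile time grows with the epoch count.
    for _ in range(50):
        x = epoch_body(x)
    return x

one_epoch = jax.jit(epoch_body)  # compiled once, called 50 times

x = jnp.ones((8,))
y = x
for _ in range(50):
    y = one_epoch(y)
print(bool(jnp.allclose(all_epochs(x), y)))  # True
```

Both variants compute the same result; the difference is that the Python-loop version compiles a single small graph and reuses it.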
I think it's worth switching to `lax.scan` and `fori_loop`. Removing the for loop within `rollout` increases the speed quite a bit and significantly reduces the compilation time. I can make a pull request for this (and for `compute_gae` and `update_ppo` as well). I compared the original `rollout` and the `lax.scan` implementation and got the following results:
```
# Original for loop
Total data collection time: 135.69225978851318 seconds
Total data collection time without compilation: 98.75351285934448 seconds
Approx. compilation time: 36.93875765800476 seconds

# with lax.scan
Total data collection time: 60.91851544380188 seconds
Total data collection time without compilation: 60.029022455215454 seconds
Approx. compilation time: 0.8895087242126465 seconds
```
The command used is: `python cleanrl/ppo_atari_envpool_xla_jax.py --env-id Breakout-v5 --total-timesteps 500000 --num-envs 32`
Note: The training code was removed because the collection time correlates with `avg_episodic_length`, which depends on the random exploration and training dynamics. Removing the training part makes sure that the numbers in this test only relate to the `rollout` function.
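The compile-time gap above is explained by how `lax.scan` traces: it traces the step function once and runs it `length` times, instead of unrolling one copy per step into the graph. A self-contained sketch (toy `body`, not the cleanrl rollout):

```python
import jax
import jax.numpy as jnp

# lax.scan traces `body` once rather than unrolling it `length` times,
# which is where the compile-time savings over a Python for loop come from.
def body(carry, _):
    return carry + 1, carry

@jax.jit
def run():
    final, ys = jax.lax.scan(body, jnp.int32(0), None, length=128)
    return final, ys

final, ys = run()
print(int(final), ys.shape)  # 128 (128,)
```

The graph size (and hence compile time) is roughly constant in `length`, while the unrolled loop's graph grows linearly with it.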
@51616 thanks for raising this issue. Could you share the snippet that derived these numbers?
Does `lax.scan` reduce the rollout time after compilation is finished? nvm, I misread something. It's interesting that the rollout time after compilation is much faster; that would be a good reason to consider using `scan`. Would you mind preparing the PR?
@vwxyzjn Here's the code:

```python
def step_once(carry, step, env_step_fn):
    (agent_state, episode_stats, next_obs, next_done, storage, key, handle) = carry
    storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)
    episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action)
    storage = storage.replace(rewards=storage.rewards.at[step].set(reward))
    return ((agent_state, episode_stats, next_obs, next_done, storage, key, handle), None)

def rollout(agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step,
            step_once_fn, max_steps):
    (agent_state, episode_stats, next_obs, next_done, storage, key, handle), _ = jax.lax.scan(
        step_once_fn,
        (agent_state, episode_stats, next_obs, next_done, storage, key, handle), (), max_steps)
    global_step += max_steps * args.num_envs
    return agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step

rollout_fn = partial(rollout,
                     step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed),
                     max_steps=args.num_steps)

for update in range(1, args.num_updates + 1):
    update_time_start = time.time()
    agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step = rollout_fn(
        agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step
    )
    if update == 1:
        start_time_wo_compilation = time.time()
    print("SPS:", int(global_step / (time.time() - start_time)))
    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
    print("SPS_update:", int(args.num_envs * args.num_steps / (time.time() - update_time_start)))
    writer.add_scalar(
        "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step
    )

print("Total data collection time:", time.time() - start_time, "seconds")
print("Total data collection time without compilation:", time.time() - start_time_wo_compilation, "seconds")
print("Approx. compilation time:", start_time_wo_compilation - start_time, "seconds")
envs.close()
writer.close()
```
I can make a PR for this. I also think we should use the output of `lax.scan` instead of replacing values in place. It might look something like this:
```python
def step_once(carry, step, env_step_fn):
    (agent_state, episode_stats, obs, done, key, handle) = carry
    action, logprob, value, key = get_action_and_value(agent_state, obs, key)
    episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action)
    storage = Storage(
        obs=obs,
        actions=action,
        logprobs=logprob,
        dones=done,
        values=value,
        rewards=reward,
        returns=jnp.zeros_like(reward),
        advantages=jnp.zeros_like(reward),
    )
    return ((agent_state, episode_stats, next_obs, next_done, key, handle), storage)

def rollout(agent_state, episode_stats, next_obs, next_done, key, handle,
            step_once_fn, max_steps):
    (agent_state, episode_stats, next_obs, next_done, key, handle), storage = jax.lax.scan(
        step_once_fn,
        (agent_state, episode_stats, next_obs, next_done, key, handle), (), max_steps)
    return agent_state, episode_stats, next_obs, next_done, key, handle, storage

rollout_fn = partial(rollout,
                     step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed),
                     max_steps=args.num_steps)

for update in range(1, args.num_updates + 1):
    update_time_start = time.time()
    agent_state, episode_stats, next_obs, next_done, key, handle, storage = rollout_fn(
        agent_state, episode_stats, next_obs, next_done, key, handle
    )
    if update == 1:
        start_time_wo_compilation = time.time()
    global_step += args.num_steps * args.num_envs
    ...
```
The code is a bit cleaner and uses the output from `lax.scan` directly.
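This works because `lax.scan` stacks each per-step output along a new leading axis, so returning a per-step record (here a plain dict as a stand-in for the `Storage` dataclass) yields the full `(num_steps, ...)` buffers without any indexed writes:

```python
import jax
import jax.numpy as jnp

# Each per-step output of lax.scan is stacked along a new leading axis,
# so a per-step record becomes full (num_steps, ...) buffers for free.
def step(carry, _):
    carry = carry + 1
    return carry, {"obs": jnp.full((2,), carry), "reward": carry * 1.0}

final, storage = jax.lax.scan(step, jnp.int32(0), None, length=4)
print(storage["obs"].shape, storage["reward"].shape)  # (4, 2) (4,)
```

Since `Storage` is a pytree, scan stacks every field the same way the dict fields are stacked here.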
@vwxyzjn Was there any reason why this wasn't merged in the end?
Nothing really. If you'd like, feel free to take on the PR :)
Description

Types of changes

Checklist:
- `pre-commit run --all-files` passes (required).
- Documentation is previewed with `mkdocs serve`.
- If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.
- Videos are added with the `--capture-video` flag toggled on (required), and previewed with `mkdocs serve` with explicit image sizes (e.g., `width=500` and `height=300`).