Closed vwxyzjn closed 1 year ago
The latest updates on your projects. Learn more about Vercel for Git ↗️

Name | Status | Preview | Comments | Updated
---|---|---|---|---
cleanrl | ✅ Ready (Inspect) | Visit Preview | 💬 Add your feedback | Feb 19, 2023 at 2:03AM (UTC)
Hey @51616, happy new year. This is the implementation that works with EnvPool's async API, which requires quite a bit of refactoring. It should scale to cases where the environments are slow and/or the models are large. I'm running some benchmark experiments, but figured you might be interested in taking a look :)
@vwxyzjn Any specific part you want me to take a look at? Btw, I'm not quite familiar with async environments but I can help you review/test the code if needed.
Added `cleanrl/sebulba_ppo_envpool.py` for the podracer architecture, which could potentially help with #350. It does not work at all yet.
Todo items:

- `SPS_update` calculation for the actor ✅ helps!
- `jnp.array_split` within JIT 🤔 same speed!
- hypothesis: if there is an `update` function that runs on GPU1, then the execution of `update` will block the `jax.device_put_sharded` call in a separate thread that tries to put data from GPU0 to GPU1. Not sure if this is the case for TPU as well.

Seems to work ok now... This opens up new possibilities because we can use SPMD for learner updates.
CC @kinalmehta @51616 @shermansiu. Btw @shermansiu, don't worry about using this for muesli yet; a lot of work needs to be done on this PR before it is stable and usable.
Thanks, good to know!
Some SPS improvement... Interestingly, using 1 GPU for both the actor and the learner performs only ever so slightly slower than using 1 GPU for the actor and 1 GPU for the learner.
Furthermore, c2b18b5 experimented with pmap (SPMD) and worked really well with 2 GPUs (GPU A used for inference, GPUs A and B used for SPMD)! Almost twice as fast as the baseline.
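An SPMD learner update of this shape might look like the sketch below. This is a toy illustration under assumed names (`update`, `multi_device_update`, a quadratic loss); the real PR replicates a full PPO train state across the learner GPUs.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()  # the learner devices (2 GPUs in the experiment above)


def update(params, batch):
    # toy quadratic loss standing in for the PPO loss
    grads = jax.grad(lambda p: jnp.mean((batch @ p) ** 2))(params)
    # average gradients across learner devices: this is the SPMD step
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 0.01 * grads


# pmap compiles one program and runs it on every learner device
multi_device_update = jax.pmap(update, axis_name="devices", devices=devices)

params = jax.device_put_replicated(jnp.ones(4), devices)  # same params everywhere
batch = jnp.ones((len(devices), 8, 4))                    # one data shard per device
params = multi_device_update(params, batch)
```

Each device computes gradients on its own shard and `pmean` keeps the replicated parameters in sync, which is why adding a second learner GPU nearly doubles throughput here.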
`threading` is fine (`multiprocessing` not necessary). `sebulba_ppo_envpool_new.py` in ab732a6 basically runs the learners non-stop while trying to step the actors as fast as possible. There's no communication between the actor and learners, so I was just testing to see if the learners would slow down the actor.
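The no-communication stress test can be sketched with plain `threading`, which is the point of the note above: separate Python threads for the actor and learner loops, no queue between them. All names here are hypothetical stand-ins for the real rollout and update loops.

```python
import threading
import time

steps = {"actor": 0, "learner": 0}


def actor(stop):
    # step environments as fast as possible; nothing is handed to the learner
    while not stop.is_set():
        steps["actor"] += 1


def learner(stop):
    # run updates non-stop on fake data; nothing is read from the actor
    while not stop.is_set():
        steps["learner"] += 1


stop = threading.Event()
threads = [
    threading.Thread(target=actor, args=(stop,)),
    threading.Thread(target=learner, args=(stop,)),
]
for t in threads:
    t.start()
time.sleep(0.2)
stop.set()
for t in threads:
    t.join()
```

If the learner loop were to slow the actor loop down even without any shared data, the cause would have to be device or GIL contention rather than the pipeline itself, which is what the experiment isolates.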
There were two settings:

- `--actor-device-ids 0 --learner-device-ids 1 2`: actor and learners have separate devices
- `--actor-device-ids 0 --learner-device-ids 0 1`: actor shares one of the learners' devices

The following figures suggest `--actor-device-ids 0 --learner-device-ids 1 2` has much higher SPS, meaning that having separate devices is key to not slowing down the actor, especially when the learning time is long (e.g., when learning time is short, `--actor-device-ids 0 --learner-device-ids 0 1` might perform just as fast).
I think under the hood this means that calling `multi_device_update` will utilize the learners' devices and block access to those devices from other Python threads, effectively slowing down the actor. However, if the actor has its own device, the actor's speed is unaffected.
Supported multiple threads on an actor GPU, per the podracer paper. Learning properly with multiple threads on an actor GPU is not tested for 1d85943.

> To generate experience, we use (at least) one separate Python thread for each actor core [...] To make efficient use of the actor cores, it is essential that while a Python thread is stepping a batch of environments, the corresponding TPU core is not idle. This is achieved by creating multiple Python threads per actor core, each with its own batched environment. The threads alternate in using the same actor core, without manual synchronization.
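The alternation the paper describes can be sketched with a lock standing in for the single actor device; the lock and names below are illustrative, not the PR's implementation. Each thread steps its own batched environment outside the lock, so while one thread is stepping, the other can run inference on the shared core.

```python
import threading

actor_core = threading.Lock()  # stand-in for the single shared actor device
log = []


def rollout(thread_id, env_steps=3):
    for step in range(env_steps):
        # environment stepping happens outside the lock, so the other
        # thread can use the actor core in the meantime
        with actor_core:
            # inference on the shared actor core
            log.append((thread_id, step))


threads = [threading.Thread(target=rollout, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

No manual scheduling is needed: the lock serializes inference while env stepping overlaps naturally, which is exactly the "alternate without manual synchronization" behavior.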
a93a1f5 calls `jax.device_put_sharded` in the learner, which makes the actor thread run a lot faster. The learning time also suffers very little, which is great. My hypothesis is that an `update` function running on GPU1 will block a `jax.device_put_sharded` call in a separate thread that tries to put data from GPU0 to GPU1 (not sure if this is the case for TPU as well); a93a1f5 moves the `device_put_sharded` call out of the actor threads, therefore unblocking the actor. Not sure what the implication is for multi-GPU or TPU.

> **Note** This is a major difference from the podracer paper, in which I presume the `jax.device_put_sharded` happens in the actor threads.
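The learner-side transfer might look like the sketch below. This is a toy illustration (`learn` and the batch shapes are hypothetical): the host batch is split and placed onto the learner devices inside the learner thread, so the actor thread never has to wait on the transfer.

```python
import jax
import jax.numpy as jnp

learner_devices = jax.local_devices()


def learn(host_batch):
    # shard the rollout batch across learner devices *in the learner thread*,
    # so the actor thread is never blocked on the host-to-device transfer
    shards = jnp.array_split(host_batch, len(learner_devices))
    sharded = jax.device_put_sharded(shards, learner_devices)
    return sharded


# batch size chosen as a multiple of the device count so shards are equal
out = learn(jnp.ones((8 * len(learner_devices), 4)))
```

`jax.device_put_sharded` takes one array per device and returns a stacked array with a leading device axis, ready to feed into a pmapped update.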
0eefb50 uses `jax.device_put_replicated` for the `agent_state`, which slightly improves SPS.
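A minimal sketch of what that replication might look like, with a hypothetical `agent_state` pytree standing in for the real PPO train state: the state is copied to every learner device once, instead of being re-transferred on each update.

```python
import jax
import jax.numpy as jnp

learner_devices = jax.local_devices()

# hypothetical agent state; the real one holds the network params and optimizer state
agent_state = {"params": jnp.ones(4), "step": jnp.zeros(())}

# device_put_replicated copies the whole pytree to each device and adds a
# leading device axis, matching what a pmapped update function expects
replicated = jax.device_put_replicated(agent_state, learner_devices)
```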
0dd591c blocks the actor for 10 seconds during the first, second, and third rollouts. Experiments found that it could improve SPS a bit and actually improves performance as well; without this commit, the actor is always generating experience for outdated parameters (1 update behind; see `stats/param_queue_size`).
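The parameter handoff behind `stats/param_queue_size` can be sketched with a bounded `queue.Queue`; the names and loop counts below are hypothetical. A blocking `get()` makes the actor wait for fresh parameters instead of rolling out on stale ones, which is the effect the 10-second block approximates.

```python
import queue
import threading

param_queue = queue.Queue(maxsize=1)  # queue size is what stats/param_queue_size tracks
seen = []


def learner(num_updates=3):
    for update in range(1, num_updates + 1):
        params = update             # pretend each update produces new params
        param_queue.put(params)     # publish them to the actor


def actor(num_rollouts=3):
    for _ in range(num_rollouts):
        params = param_queue.get()  # block until fresh params arrive
        seen.append(params)         # roll out with up-to-date params


threads = [threading.Thread(target=learner), threading.Thread(target=actor)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

With `maxsize=1` the actor can be at most one publish behind the learner; without any blocking, the lag the commit message describes (1 update behind) creeps in.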
Putting these items at the top of the PR.
Was able to match the performance of DeepMind's sebulba architecture (https://arxiv.org/pdf/2104.06272.pdf) to some capacity. My prototype (CleanBa PPO, which stands for CleanRL's sebulba PPO) can outperform the original IMPALA (deep net setting; Espeholt et al., 2018) with 5 A100 GPUs (1 actor GPU, 4 learner GPUs) and 16 CPU cores for envpool. Still WIP; need to fix some bugs.
Current results:
Got an error
```
2023-02-17 06:08:24.084633: E external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.cc:481] Failed to disconnect from coordination service with status: DEADLINE_EXCEEDED: Deadline Exceeded
Additional GRPC error information from remote target unknown_target_for_coordination_leader:
:{"created":"@1676614104.084285325","description":"Error received from peer ipv4:26.0.134.228:64719","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Deadline Exceeded","grpc_status":4}. Proceeding with agent shutdown anyway.
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/admin/home-costa/.cache/pypoetry/virtualenvs/cleanrl-BE0ShDkT-py3.8/lib/python3.8/site-packages/jax/_src/distributed.py", line 168, in shutdown
    global_state.shutdown()
  File "/admin/home-costa/.cache/pypoetry/virtualenvs/cleanrl-BE0ShDkT-py3.8/lib/python3.8/site-packages/jax/_src/distributed.py", line 87, in shutdown
    self.client.shutdown()
jaxlib.xla_extension.XlaRuntimeError: DEADLINE_EXCEEDED: Deadline Exceeded
Additional GRPC error information from remote target unknown_target_for_coordination_leader:
:{"created":"@1676614104.084285325","description":"Error received from peer ipv4:26.0.134.228:64719","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Deadline Exceeded","grpc_status":4}
2023-02-17 06:08:24.405270: E external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service.cc:1129] Shutdown barrier in coordination service has failed: DEADLINE_EXCEEDED: Barrier timed out. Barrier_id: Shutdown::15630121101087999007 [type.googleapis.com/tensorflow.CoordinationServiceError='']. This suggests that at least one worker did not complete its job, or was too slow/hanging in its execution.
2023-02-17 06:08:24.405311: E external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service.cc:731] INTERNAL: Shutdown barrier has been passed with status: 'DEADLINE_EXCEEDED: Barrier timed out. Barrier_id: Shutdown::15630121101087999007 [type.googleapis.com/tensorflow.CoordinationServiceError='']', but this task is not at the barrier yet. [type.googleapis.com/tensorflow.CoordinationServiceError='']
2023-02-17 06:08:24.405379: E external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service.cc:449] Stopping coordination service as shutdown barrier timed out and there is no service-to-client connection.
2023-02-17 06:08:53.225236: E external/org_tensorflow/tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.cc:711] Coordination agent is in ERROR: INVALID_ARGUMENT: Unexpected task request with task_name=/job:jax_worker/replica:0/task:0
Additional GRPC error information from remote target unknown_target_for_coordination_leader:
:{"created":"@1676614133.225164768","description":"Error received from peer ipv4:26.0.134.228:64719","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Unexpected task request with task_name=/job:jax_worker/replica:0/task:0","grpc_status":3} [type.googleapis.com/tensorflow.CoordinationServiceError='']
2023-02-17 06:08:53.225277: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.cc:452] Coordination service agent in error status: INVALID_ARGUMENT: Unexpected task request with task_name=/job:jax_worker/replica:0/task:0
Additional GRPC error information from remote target unknown_target_for_coordination_leader:
:{"created":"@1676614133.225164768","description":"Error received from peer ipv4:26.0.134.228:64719","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Unexpected task request with task_name=/job:jax_worker/replica:0/task:0","grpc_status":3} [type.googleapis.com/tensorflow.CoordinationServiceError='']
2023-02-17 06:08:53.226009: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/client.h:75] Terminating process because the coordinator detected missing heartbeats. This most likely indicates that another task died; see the other task logs for more details. Status: INVALID_ARGUMENT: Unexpected task request with task_name=/job:jax_worker/replica:0/task:0
Additional GRPC error information from remote target unknown_target_for_coordination_leader:
:{"created":"@1676614133.225164768","description":"Error received from peer ipv4:26.0.134.228:64719","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Unexpected task request with task_name=/job:jax_worker/replica:0/task:0","grpc_status":3} [type.googleapis.com/tensorflow.CoordinationServiceError='']
srun: error: ip-26-0-134-228: task 0: Aborted
```
Closed in favor of https://github.com/vwxyzjn/cleanba
Description
Todo items:

- `SPS_update` calculation for the actor ✅ helps!
- `jnp.array_split` within JIT 🤔 same speed!
- `async_update` is off (https://wandb.ai/costa-huang/cleanRL/reports/4-vs-3-learner-devices--VmlldzozNDg4MDE5)

More experiments

- hypothesis: if there is an `update` function that runs on GPU1, then the execution of `update` will block the `jax.device_put_sharded` call in a separate thread that tries to put data from GPU0 to GPU1. Not sure if this is the case for TPU as well.

Types of changes
Checklist:

- [ ] I have ensured `pre-commit run --all-files` passes (required).
- [ ] I have updated the documentation and previewed the changes via `mkdocs serve`.

If you are adding new algorithm variants or your change could result in a performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

- [ ] I have run benchmark experiments with the `--capture-video` flag toggled on (required).
- [ ] I have added documentation and previewed the changes via `mkdocs serve`.