world-modelz / dreamax

A scalable Dreamer implementation in JAX
MIT License

Data Parallelism on multiple accelerator devices #5

Open XMaster96 opened 2 years ago

XMaster96 commented 2 years ago

In order to fully utilise a TPU VM, we need to be able to run on multiple accelerator devices. This can be done either by running multiple independent rollout workers on separate devices (#4) or by running the training itself on multiple devices. To use multiple devices with JAX, we have to manually replicate the parameter tensors across devices, shard the data tensors, and reduce the gradients; see the sketch below.
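A minimal sketch of that pattern, assuming nothing about the dreamax codebase: `loss_fn`, the toy parameters, and the batch shapes below are placeholders, not dreamax code. It replicates the params with `jax.device_put_replicated`, shards the batch along a leading device axis, and all-reduces the gradients with `jax.lax.pmean` inside a `jax.pmap`-ed training step.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Placeholder least-squares loss standing in for the Dreamer losses.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


@partial(jax.pmap, axis_name="devices")
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # All-reduce the gradients so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss


n = jax.local_device_count()
params = {"w": jnp.zeros((4, 1))}
# Replicate the params: one copy per device, stacked along a new leading axis.
params = jax.device_put_replicated(params, jax.local_devices())
# Shard the data: the leading axis must equal the device count.
batch = {"x": jnp.ones((n, 32, 4)), "y": jnp.ones((n, 32, 1))}
params, loss = train_step(params, batch)
print(loss)  # one loss value per device replica
```

Note that `jax.pmap` requires every input's leading axis to match the device count, so the data loader has to produce per-device batches; the replicated params keep that same leading axis across steps.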

AC: