In order to fully utilise a TPU VM, we need to be able to run on multiple accelerator devices.
This can be done by running multiple independent Rollout workers on separate devices (#4) or by running training itself on multiple devices. To use multiple devices with JAX we have to manually replicate and reduce the params and data tensors.
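As a rough illustration, manually replicating a params pytree onto every local device and reducing it back could look like the minimal sketch below; the params pytree here is a placeholder, not the actual world-model params.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()

# Placeholder params pytree standing in for the world-model params.
params = {"w": jnp.zeros((16, 1)), "b": jnp.zeros((1,))}

# Replicate the pytree onto every local device (adds a leading device axis).
replicated = jax.device_put_replicated(params, devices)

# Reduce back to a single copy by averaging over the device axis.
reduced = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), replicated)
```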
AC:
vmap the JAX training loop over multiple accelerator devices.
Replicate the params and optimizer state of the training world-model instance across multiple accelerator devices.
Reshape data from the Dataloader for vmap.
Reduce logging output from the training loop back to a single instance.
Reduce params from the training world-model instances back to a single device to update the master copy (see the sketch after this list).
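One way these steps could fit together is sketched below with jax.pmap, which maps the leading axis of its inputs across devices, analogous to how vmap maps over a batch axis. The placeholder loss function, batch shapes, and learning rate are illustrative assumptions, not the actual world-model training code.

```python
import functools
import jax
import jax.numpy as jnp
import optax

num_devices = jax.local_device_count()
devices = jax.local_devices()

# Placeholder world-model params and loss; the real model and loss live elsewhere.
params = {"w": jnp.zeros((16, 1))}
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

def loss_fn(params, batch):
    preds = batch["obs"] @ params["w"]
    return jnp.mean((preds - batch["target"]) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Keep replicas in sync and reduce the logged loss across devices.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Replicate params and optimizer state onto every local device.
params_repl = jax.device_put_replicated(params, devices)
opt_state_repl = jax.device_put_replicated(opt_state, devices)

# Reshape a Dataloader batch so its leading axis matches the device count.
batch = {"obs": jnp.ones((num_devices * 32, 16)),
         "target": jnp.ones((num_devices * 32, 1))}
batch = jax.tree_util.tree_map(
    lambda x: x.reshape(num_devices, -1, *x.shape[1:]), batch)

params_repl, opt_state_repl, loss = train_step(params_repl, opt_state_repl, batch)

# Logging output reduced back to a single value (identical on every device after pmean).
print("loss:", float(loss[0]))

# Pull params back from device 0 to update the master copy on the host.
master_params = jax.tree_util.tree_map(lambda x: jax.device_get(x[0]), params_repl)
```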