instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

fix(training): support training and evaluation on multiple (tpu) workers #152

Closed: clement-bonnet closed this issue 1 year ago

clement-bonnet commented 1 year ago

Is your feature request related to a problem? Please describe

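Currently, the training loop and the evaluators assume that only local devices (those attached to the current host) are used. This works on TPU slices up to a TPU-v-8, which is served by a single host, but breaks from TPU-v-16 onward, where the slice spans several hosts and each host only sees its own devices.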
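A quick way to see why the local-devices assumption stops holding is to compare what JAX reports per host and globally. The helper below, `describe_topology`, is a hypothetical illustration (not part of Jumanji); it only calls standard JAX functions. On a single host the local and global device counts coincide, while on a multi-host slice the global count is the sum over all hosts.

```python
import jax


def describe_topology() -> dict:
    """Hypothetical helper: summarise how JAX sees the devices of this job."""
    return {
        "process_index": jax.process_index(),       # id of this host (0 on a single host)
        "process_count": jax.process_count(),       # number of hosts in the job
        "local_devices": jax.local_device_count(),  # devices attached to this host
        "global_devices": jax.device_count(),       # devices across all hosts
    }


topo = describe_topology()
# More than one process means code that only looks at local devices
# sees just a fraction of the slice.
is_multi_host = topo["process_count"] > 1
print(topo, "multi-host:", is_multi_host)
```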

Describe the solution you'd like

Use jax.process_index() to select a different random key on each host, so that the hosts do not all run identical rollouts.
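A minimal sketch of the proposed fix, assuming all hosts start from the same integer seed. The names `per_host_key` and `per_device_keys` are hypothetical, not Jumanji's API, and `jax.random.fold_in` is just one way to mix the host id into the key (splitting a key into `jax.process_count()` pieces and indexing by `jax.process_index()` would also work).

```python
import jax


def per_host_key(seed: int) -> jax.Array:
    """Hypothetical helper: a PRNG key that differs across hosts but is reproducible."""
    base_key = jax.random.PRNGKey(seed)
    # fold_in mixes the host id into the key, so host 0, host 1, ... each get a
    # distinct random stream while every host starts from the same seed.
    return jax.random.fold_in(base_key, jax.process_index())


def per_device_keys(seed: int) -> jax.Array:
    """Hypothetical helper: split the host key into one key per local device (e.g. for pmap)."""
    return jax.random.split(per_host_key(seed), jax.local_device_count())


keys = per_device_keys(0)
print(keys.shape)
```

Deriving keys from a shared seed keeps runs reproducible while guaranteeing that no two hosts, and no two devices on the same host, consume the same random stream.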