Closed: clement-bonnet closed this issue 1 year ago.
Is your feature request related to a problem? Please describe
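Right now, the training loop and the evaluators assume that all devices are local to a single host. This works on TPU slices up to TPU-v-8, which run on one host, but breaks from TPU-v-16 on, where the slice spans multiple hosts.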
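To make the local-versus-global distinction concrete, here is a minimal sketch that reports what one process sees, using only standard JAX calls (`jax.process_index`, `jax.process_count`, `jax.local_device_count`, `jax.device_count`). The helper name `describe_topology` is hypothetical and not part of the project's code.

```python
import jax


def describe_topology() -> dict:
    """Summarize the device layout seen by the current process (hypothetical helper)."""
    return {
        "process_index": jax.process_index(),        # which host this is
        "process_count": jax.process_count(),        # number of hosts in the job
        "local_device_count": jax.local_device_count(),  # devices attached to this host
        "global_device_count": jax.device_count(),   # devices across all hosts
    }


topology = describe_topology()
# On a single host the local and global counts match. On a multi-host slice
# (e.g. TPU-v-16 and up) the global count is larger, so code that only
# reasons about local devices sees just part of the slice.
is_multi_host = topology["process_count"] > 1
print(topology, "multi-host:", is_multi_host)
```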
Describe the solution you'd like
Use `jax.process_index()` to select different PRNG keys for each host.
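One possible way to do this, sketched under assumptions: every host starts from the same seed, folds its `jax.process_index()` into the key with `jax.random.fold_in`, and then splits the result into one key per local device (e.g. for a `pmap`-ed step). The use of `fold_in` and the helper name `per_host_keys` are illustrative choices, not the fix adopted by the project.

```python
import jax


def per_host_keys(seed: int) -> jax.Array:
    """Return one PRNG key per local device, distinct across hosts (hypothetical helper)."""
    base_key = jax.random.PRNGKey(seed)  # identical on every host
    # Fold in the host index so each host derives its own key stream.
    host_key = jax.random.fold_in(base_key, jax.process_index())
    # One key per local device, e.g. to feed a pmap-ed training or evaluation step.
    return jax.random.split(host_key, jax.local_device_count())


keys = per_host_keys(0)
print(keys.shape)
```

Folding in the index keeps the seed handling unchanged for single-host runs while guaranteeing that two hosts never reuse the same randomness.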