Hi, when I run llama_train.py in distributed mode on a v3-512 TPU pod with evaluation turned on (eval_steps > 0), I get this error:
RuntimeError: Running operations on `Array`s that are not fully addressable by this process (i.e. `Array`s with data sharded across multiple devices and processes.) is dangerous. It’s very important that all processes run the same cross-process computations in the same order otherwise it can lead to hangs. If you’re not already familiar with JAX’s multi-process programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix this error, run your `jitted` computation inside `with jax.spmd_mode('allow_all'):` context manager.
This happens at this line in the code:
Could you please help me with this? Thank you very much!
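For reference, the error says that eager (non-compiled) operations on arrays sharded across processes are refused, and it suggests `jax.spmd_mode('allow_all')`. A different way to stay within JAX's multi-process model is to keep that math inside `jax.jit`, so every process runs the same compiled cross-process computation in the same order. The sketch below is a minimal illustration under two assumptions that the traceback does not confirm: that the failing operation is an eager reduction over evaluation metrics, and that the installed JAX version allows jitted computations on non-fully-addressable arrays by default. `average_eval_metrics` is a hypothetical helper, not a function from llama_train.py.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not part of llama_train.py): averages a list of
# per-step scalar eval losses. Doing the stack and mean inside jax.jit means
# the reduction is a single compiled program that every process executes,
# instead of eager jnp ops on arrays that may span multiple processes.
@jax.jit
def average_eval_metrics(per_step_losses):
    return jnp.mean(jnp.stack(per_step_losses))


# Usage on a single host: the list would normally hold the outputs of the
# jitted eval step, one scalar per eval batch.
losses = [jnp.float32(1.0), jnp.float32(2.0), jnp.float32(3.0)]
mean_loss = average_eval_metrics(losses)
print(float(mean_loss))
```

Passing a list of a different length triggers a retrace, which is harmless here as long as eval_steps is fixed for the run.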