young-geng / EasyLM

Large language models (LLMs) made easy. EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating and serving LLMs in JAX/Flax.
Apache License 2.0
2.33k stars, 247 forks

[Bug] Error in Evaluation #111

Open LeoXinhaoLee opened 4 months ago

LeoXinhaoLee commented 4 months ago

Hi, when I run llama_train.py in a distributed setup on a v3-512 TPU pod with evaluation turned on (eval_steps > 0), I get this error:

RuntimeError: Running operations on `Array`s that are not fully addressable by this process
(i.e. `Array`s with data sharded across multiple devices and processes.) is dangerous. It’s
very important that all processes run the same cross-process computations in the same order
otherwise it can lead to hangs. If you’re not already familiar with JAX’s multi-process
programming model, please read https://jax.readthedocs.io/en/latest/multi_process.html. To fix
this error, run your `jitted` computation inside `with jax.spmd_mode('allow_all'):` context
manager.

This happens at this line in the code:

average_metrics(eval_metric_list)
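
For reference, here is a minimal sketch of the workaround the error message itself names, which is to wrap the averaging in `jax.spmd_mode('allow_all')`. The function `average_metrics_sketch` and the metric names are hypothetical stand-ins, not EasyLM's actual code, and the `getattr` fallback is an assumption in case a given JAX version does not expose `jax.spmd_mode`. This runs on a single process; on a pod it would only be safe if every process executes the same computation in the same order, as the message warns.

```python
import contextlib

import jax
import jax.numpy as jnp


def average_metrics_sketch(metric_list):
    """Average a list of metric dicts leaf by leaf.

    Hypothetical stand-in for the averaging step, not EasyLM's implementation.
    """
    # jax.spmd_mode is the context manager named in the error message.
    # Assumption: fall back to a no-op context if this JAX version lacks it.
    spmd_mode = getattr(jax, "spmd_mode", None)
    ctx = spmd_mode("allow_all") if spmd_mode is not None else contextlib.nullcontext()
    with ctx:
        # Stack the matching leaves from every dict and take their mean.
        return jax.tree_util.tree_map(
            lambda *leaves: jnp.mean(jnp.stack(leaves)), *metric_list
        )


# Illustrative values only; real eval metrics would come from the eval step.
eval_metric_list = [
    {"eval_loss": jnp.array(2.0), "eval_accuracy": jnp.array(0.50)},
    {"eval_loss": jnp.array(4.0), "eval_accuracy": jnp.array(0.75)},
]
averaged = average_metrics_sketch(eval_metric_list)
```

I am not sure whether this is the right fix here, or whether the metrics should instead be gathered to every host before averaging.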

Could you please help me with this? Thank you very much!