Hi,
I am trying to train GPT Neo on a TPU by following the guide section you provide in configure_flax.md.
However, I get the following error:
File "run_clm_mp.py", line 41, in <module>
    from jax.experimental.maps import mesh
ImportError: cannot import name 'mesh' from 'jax.experimental.maps' (/home/ali.najafi/flax/lib/python3.8/site-packages/jax/experimental/maps.py)
Library versions: