NVIDIA / JAX-Toolbox


'jax.experimental.maps' import error #962

Open MikeMpapa opened 1 month ago

MikeMpapa commented 1 month ago

Hi, I am trying to use the levanter image, but I get the following error: ModuleNotFoundError: No module named 'jax.experimental.maps'.

Was the module renamed? It worked fine yesterday.

Thanks!

The complete error log:


```
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 119, in main
    Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 597, in round_axis_for_partitioning
    size = physical_axis_size(axis, mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 566, in physical_axis_size
    mesh = _get_mesh()
  File "/opt/haliax/src/haliax/partitioning.py", line 606, in _get_mesh
    from jax.experimental.maps import thread_resources
ModuleNotFoundError: No module named 'jax.experimental.maps'
```

chaserileyroberts commented 1 month ago

I recently hit this error too, on a separate problem.

I think JAX recently removed maps from jax.experimental. It had been deprecated for a while: https://github.com/google/jax/blob/5e418f5ab2692d4791816e85ed82eb0834a579cb/CHANGELOG.md?plain=1#L284
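
To confirm that this is what happened inside a given image, a quick check like the one below may help. It is only a sketch: the helper names `module_available` and `installed_version` are made up for illustration, and it assumes nothing beyond the Python standard library.

```python
import importlib.util
from importlib import metadata


def module_available(dotted_name):
    """Return True if `dotted_name` can be located by the import system."""
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. `jax` itself) is not installed.
        return False


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


print("jax version:", installed_version("jax"))
print("jax.experimental.maps importable:",
      module_available("jax.experimental.maps"))
```

If the second line prints False while a jax version is reported, the installed JAX no longer ships that module, which would match the traceback above.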

HMUNACHI commented 1 month ago

Problem: This comes from the Haliax package: see here. After the recent refactoring, thread_resources can no longer be imported from jax.experimental.maps (I believe it now lives in jax._src.mesh), so the import in Haliax needs to be updated.

Solution: You can fork the Haliax repo yourself, fix the import, and replace the Haliax link in the Docker image here with a link to your fork.
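
As a rough idea of what the fix in a fork could look like, here is a sketch of a version-tolerant import. It is not the actual Haliax patch. The helper `import_first` and the candidate list are hypothetical, and the two newer locations (`jax.interpreters.pxla` and the private `jax._src.mesh`) are assumptions that should be checked against the JAX version in the image. The `env.physical_mesh` attribute access is likewise an assumption about what `_get_mesh` needs.

```python
import importlib


def import_first(name, candidates):
    """Return attribute `name` from the first module in `candidates` that has it."""
    errors = []
    for module_path in candidates:
        try:
            module = importlib.import_module(module_path)
            return getattr(module, name)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_path}: {exc}")
    raise ImportError(f"could not import {name!r}; tried: " + "; ".join(errors))


# Candidate locations, oldest first. Only the first one is confirmed by the
# traceback; the other two are assumptions to verify for your JAX version.
THREAD_RESOURCES_CANDIDATES = (
    "jax.experimental.maps",
    "jax.interpreters.pxla",
    "jax._src.mesh",
)


def get_mesh():
    """Hypothetical stand-in for Haliax's _get_mesh()."""
    thread_resources = import_first("thread_resources", THREAD_RESOURCES_CANDIDATES)
    # Assumption: the active mesh is exposed as env.physical_mesh.
    return thread_resources.env.physical_mesh
```

Trying the old location first keeps the fork working on older JAX releases, and the error message lists every location that was tried if none of them work.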