stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

Pip install gives error on TPU cluster #700

Open abhinavg4 opened 2 months ago

abhinavg4 commented 2 months ago

The following setup commands:

pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install levanter

give the following error when importing levanter:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/.local/lib/python3.10/site-packages/levanter/__init__.py", line 4, in <module>
    import levanter.distributed as distributed
  File "/home/ubuntu/.local/lib/python3.10/site-packages/levanter/distributed.py", line 12, in <module>
    from jax._src.clusters import SlurmCluster, TpuCluster
ImportError: cannot import name 'TpuCluster' from 'jax._src.clusters' (/home/ubuntu/.local/lib/python3.10/site-packages/jax/_src/clusters/__init__.py)

Building from source works fine.

dlwh commented 2 months ago

The latest Levanter release on PyPI is very old, so it no longer matches current JAX. You need to use the git install instead.
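
For anyone hitting the same traceback, here is a minimal sketch of a check for whether the installed JAX still exposes the private names that the released levanter/distributed.py imports. The helper name `module_exports` is hypothetical and not part of Levanter or JAX, and `jax._src.clusters` is a private module whose contents can change between JAX versions. If `TpuCluster` shows up as missing, the git install is probably the way to go; assuming the standard pip VCS syntax and the repository name shown at the top of this page, that would be something like `pip install git+https://github.com/stanford-crfm/levanter.git`.

```python
import importlib


def module_exports(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and defines `attr`.

    Hypothetical helper for illustration only; it returns False when the
    module itself cannot be imported (for example, JAX is not installed).
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # These are the two names imported in the traceback above.
    for name in ("SlurmCluster", "TpuCluster"):
        found = module_exports("jax._src.clusters", name)
        print(f"jax._src.clusters.{name}: {'found' if found else 'missing'}")
```

This only reports what the installed JAX provides; it does not change either package.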