kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars, 890 forks

The latest update to PIP breaks installation #210

Open markriedl opened 2 years ago

markriedl commented 2 years ago

The latest update of pip seems to have changed dependency resolution in a way that makes mesh-transformer-jax hard to install. Installation now requires an older version of pip, specific versions of packages such as 'transformers', and a forced installation order.

These are the steps I had to take:

  1. pip install pip==22.0
  2. Edit requirements.txt:
    • transformers==4.16.2
    • fastapi==0.73.0
    • uvicorn==0.17.1
  3. pip install jax==0.2.12 tensorflow==2.5.0 (same as before, but it now has to come earlier)
  4. pip install -r mesh-transformer-jax/requirements.txt
  5. pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0
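Step 2 above can also be done with a small script instead of editing the file by hand. This is only a sketch: `pin_requirements` and `PINS` are hypothetical names, not part of the repository, and it assumes requirements.txt uses the usual one-package-per-line format (its actual contents are not shown in this thread).

```python
import re

# Pins taken from step 2 of the list above.
PINS = {
    "transformers": "4.16.2",
    "fastapi": "0.73.0",
    "uvicorn": "0.17.1",
}

# Matches the package name at the start of a requirement line.
# Comment lines ("# ...") and option lines ("-r ...") do not match.
_NAME = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def pin_requirements(text, pins=PINS):
    """Return requirements text with the given packages pinned exactly.

    Lines for other packages, comments and blank lines are kept unchanged.
    Any extras or markers on a pinned line are dropped (a simplification).
    Pinned packages missing from the input are appended at the end.
    """
    seen = set()
    out = []
    for line in text.splitlines():
        match = _NAME.match(line)
        name = match.group(1).lower() if match else None
        if name in pins:
            out.append(f"{name}=={pins[name]}")
            seen.add(name)
        else:
            out.append(line)
    for name, version in pins.items():
        if name not in seen:
            out.append(f"{name}=={version}")
    return "\n".join(out) + "\n"
```

To use it, read mesh-transformer-jax/requirements.txt, pass its text through `pin_requirements`, and write the result back before running step 4. The other steps still have to be run in the order listed.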

danyaljj commented 2 years ago

For me too, the installation of the requirements takes forever.

jchwenger commented 2 years ago

Same here: even when following @markriedl's steps, pip still searches for hours. I know these days it would be more straightforward to just use the model at goose.ai, but I wanted to run some experiments with sampling, etc. It'd be great if it could be made to work again!

markriedl commented 2 years ago

I wonder if I didn't transcribe the order of operations right. This notebook won't take too long to run: https://colab.research.google.com/drive/17zvUhLcpjUKJdTRg00HYdGMEN3uoMy-M?usp=sharing

jchwenger commented 2 years ago

Thanks! It works again.

PhilWicke commented 2 years ago

> I wonder if I didn't transcribe the order of operations right. This notebook won't take too long to run: https://colab.research.google.com/drive/17zvUhLcpjUKJdTRg00HYdGMEN3uoMy-M?usp=sharing

Thanks for this fix. Am I the only one for whom network.state = read_ckpt(network.state, "step_383500/", devices.shape[1]) causes a RAM overflow that terminates the session?

lightyrs commented 2 years ago

With your example, @markriedl, I'm getting AttributeError: module 'jax.random' has no attribute 'KeyArray' when running import optax
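A quick way to check whether the installed jax exposes the attribute named in that error, before importing optax, might look like the sketch below. `module_has_attr` is a hypothetical helper, not part of jax or optax. It only reports whether the attribute is present and does not fix the version mismatch.

```python
import importlib


def module_has_attr(module_name, attr):
    """Return True if the module imports and exposes attr, False otherwise."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module is not installed (or failed to import) in this environment.
        return False
    return hasattr(module, attr)


# Example call, using the names from the error message above:
#   module_has_attr("jax.random", "KeyArray")
```

If the call returns False, the installed jax does not provide what the importing package expects, and the package versions need to be aligned.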

lightyrs commented 2 years ago

Sorry @markriedl, I see the solution is posted in https://github.com/kingoflolz/mesh-transformer-jax/issues/221