kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars 890 forks

Colab Demo Notebook Not Working #167

Open CircuitGuy opened 2 years ago

CircuitGuy commented 2 years ago

I launched the Colab notebook to try and demo this model.

There's a section that starts with: Sometimes the next step errors for some reason, just run it again ¯\_(ツ)_/¯

That's fine and all, except running it multiple times didn't help. To resolve some of the errors, I ran: !pip install optax transformers ray

That got me closer, but it errors out with:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-18-a22d9a83aa66> in <module>()
      7 import transformers
      8 
----> 9 from mesh_transformer.checkpoint import read_ckpt_lowmem
     10 from mesh_transformer.sampling import nucleaus_sample
     11 from mesh_transformer.transformer_shard import CausalTransformer

2 frames
/usr/lib/python3.7/typing.py in __new__(cls, *args, **kwds)
    308                 isinstance(args[1], tuple)):
    309             # Close enough.
--> 310             raise TypeError(f"Cannot subclass {cls!r}")
    311         return super().__new__(cls)
    312 

TypeError: Cannot subclass <class 'typing._SpecialForm'>

sharaku17 commented 2 years ago

I am seeing the same error. I would appreciate it if somebody could help.

texturejc commented 2 years ago

Yes, I'm also getting this exact error. Any feedback appreciated!

sharaku17 commented 2 years ago

I was able to solve this. Let the first cell run, which installs all the requirements from requirements.txt, and afterwards run the following two pip installs:

!pip install optax==0.0.9 transformers dm-haiku einops

and

!pip install ray

After installing these, I was able to run the following cells in the notebook without any problems.

See: #161 Aspie96 Comment

joan0fsnark commented 2 years ago

I'm getting this exact error and nothing has resolved it, including @Aspie96's helpful installation info.

I get two errors at this stage.

Sometimes the next step errors for some reason, just run it again ¯\_(ツ)_/¯

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-16-a22d9a83aa66> in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8
---------------------------------------------------------------------------
/usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py in <module>()
     34 Scalar = Union[float, int]
     35 Numeric = Union[Array, Scalar]
---> 36 PRNGKey = jax.random.KeyArray
     37 PyTreeDef = type(jax.tree_structure(None))
     38 Shape = jax.core.Shape

AttributeError: module 'jax.random' has no attribute 'KeyArray'
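
The traceback above fails inside chex while it looks up jax.random.KeyArray, which points to a version mismatch between the installed chex and jax. That reading is an assumption drawn from the traceback, not something confirmed in this thread. A minimal sketch for checking whether a module exposes an attribute before importing the packages that depend on it; the helper name `has_attribute` is hypothetical:

```python
import importlib


def has_attribute(module_name, dotted_attr):
    """Return True/False if the attribute path exists, or None if the module is missing.

    `has_attribute` is an illustrative helper, not part of jax or chex.
    """
    try:
        obj = importlib.import_module(module_name)
    except ImportError:
        # The module itself is not installed in this runtime.
        return None
    for part in dotted_attr.split("."):
        if not hasattr(obj, part):
            return False
        obj = getattr(obj, part)
    return True


if __name__ == "__main__":
    # False here would match the AttributeError shown in the traceback.
    print(has_attribute("jax.random", "KeyArray"))
```

If the check returns False, pinning chex to a release that matches the installed jax (as the pinned list later in this thread does) is one thing to try.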
Raval-Arth commented 2 years ago

I'm getting this error:

ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-2-a22d9a83aa66> in <module>()
      7 import transformers
      8 
----> 9 from mesh_transformer.checkpoint import read_ckpt_lowmem
     10 from mesh_transformer.sampling import nucleaus_sample
     11 from mesh_transformer.transformer_shard import CausalTransformer

ModuleNotFoundError: No module named 'mesh_transformer'

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.
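
A ModuleNotFoundError for mesh_transformer usually means the package is not installed in the current runtime, for example because the notebook's install cell did not finish or the runtime was reset. That explanation is an assumption; the thread does not confirm the cause. A small sketch that checks this before running the import cell; the helper name `is_importable` is hypothetical:

```python
import importlib.util


def is_importable(module_name):
    """Return True if a top-level module can be found, without importing it.

    `is_importable` is an illustrative helper; pass top-level names only.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for broken or oddly named modules.
        return False


if __name__ == "__main__":
    for name in ("mesh_transformer", "optax", "transformers", "ray"):
        status = "found" if is_importable(name) else "MISSING"
        print(f"{name}: {status}")
```

Any name reported as MISSING needs to be installed in the same runtime before the import cell is run again.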
ruze00 commented 2 years ago

This whole project has been quite the waste of time.

JohnnyOpcode commented 2 years ago

I beg to differ. It takes some effort to get the right mix of dependencies, but when it works it works quite well.

Here is what I use in my Colab Pro to bootstrap the model.

These are the packages that work and bypass requirements.txt (which is outdated/quirky on dependencies):

!pip install numpy~=1.21.0
!pip install typing-extensions~=3.7.4
!pip install "tqdm>=4.45.0"
!pip install "wandb>=0.11.2"
!pip install einops~=0.3.0
!pip install requests~=2.25.1
!pip install fabric~=2.6.0
!pip install optax==0.0.9
!pip install dm-haiku==0.0.5
!pip install git+https://github.com/EleutherAI/lm-evaluation-harness/
!pip install ray[default]==1.4.1
!pip install jaxlib~=0.1.68
!pip install jax~=0.2.12
!pip install Flask~=1.1.2
!pip install cloudpickle~=1.3.0
!pip install tensorflow-cpu~=2.7.0
!pip install google-cloud-storage~=1.36.2
!pip install transformers
!pip install smart_open[gcs]
!pip install func_timeout
!pip install ftfy
!pip install fastapi
!pip install uvicorn
!pip install lm_dataformat
!pip install pathy

!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.7.0 chex==0.1.2 jaxlib==0.1.68
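
After running the installs above, it can help to confirm that the runtime really ended up with the pinned versions, since a later install may silently upgrade an earlier one. The sketch below compares installed versions against a few exact pins taken from the commands above. The names `PINS`, `installed_version`, and `find_mismatches` are hypothetical, and the exact-match comparison is stricter than the `~=` specifiers in the list:

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:
    # Python 3.7 (as in the tracebacks above); assumes the backport is installed.
    import importlib_metadata as metadata

# Exact pins copied from the install commands in this comment.
PINS = {
    "jax": "0.2.12",
    "jaxlib": "0.1.68",
    "chex": "0.1.2",
    "optax": "0.0.9",
    "dm-haiku": "0.0.5",
}


def installed_version(name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Map each package whose installed version differs to (wanted, found)."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f"{name}: wanted {wanted}, found {found}")
```

An empty result means every listed package matches its pin. The `lookup` parameter exists only so the comparison can be exercised without touching the real environment.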

jweber00 commented 1 year ago

I have successfully installed and run the application using the recommended dependencies, although two packages required a higher version.

hovanesgasparian commented 1 year ago

Thank you so much @JohnnyOpcode! I was going nuts trying to figure out all the errors. Using your list, I was able to proceed and run the notebook to completion!

JohnnyOpcode commented 1 year ago

There is another recent issue posted where wheel versions are further tweaked. GPT-J on JAX is brilliant work and just needs some love and attention when it comes to dependencies. Have a look at that and maybe it will reap further rewards.

Bravo to @kingoflolz on this outstanding bit (pun) of work.

Aspie96 commented 1 year ago

Well, it's obvious that it needs attention. It's all anyone needs.