kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.26k stars, 890 forks

TypeError: Cannot subclass <class 'typing._SpecialForm'> while fine-tuning #222

Open: samyakai opened this issue 2 years ago

samyakai commented 2 years ago

I am trying to fine-tune GPT-J on custom data using a TPU. When I run "device_train.py" with the documented command, "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/", I get this error:

Traceback (most recent call last):
  File "device_train.py", line 13, in <module>
    from mesh_transformer import util
  File "/home/shreyjain/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
  File "/usr/lib/python3.8/typing.py", line 317, in __new__
    raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>

OS: Ubuntu 20.04
TPU: v3-8
Python: 3.8 and 3.7 (both give the error)

I have no idea what this error means. Any help would be appreciated! Thank you.
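The traceback ends at a class statement in util.py that uses OptState as its base class. Below is a minimal, stdlib-only sketch of that failure mode. It assumes (as the optax-downgrade fixes later in this thread suggest, but which is not confirmed here) that newer optax releases define OptState as a typing alias such as a Union, while older ones exposed an actual class. OldOptState and NewOptState are hypothetical stand-ins, not optax names. Subclassing a real class works; subclassing a typing alias raises TypeError. On Python 3.8 the alias case should produce the same message as the traceback above; other Python versions word it differently.

```python
from typing import NamedTuple, Union

# Hypothetical stand-in for a state type that is a real class
# (assumed to resemble how older optax exposed OptState).
OldOptState = NamedTuple

# Hypothetical stand-in for a state type that is only a typing alias
# (assumed to resemble newer optax, where OptState is a Union-style alias).
NewOptState = Union[int, float, tuple]


def try_subclass(base):
    """Subclass `base` the same way util.py subclasses OptState."""
    try:
        class ClipByGlobalNormState(base):
            """Stateless transformation state (body irrelevant here)."""
    except TypeError as err:
        return f"TypeError: {err}"
    return "ok"


if __name__ == "__main__":
    print("class base:  ", try_subclass(OldOptState))
    print("typing alias:", try_subclass(NewOptState))
```

A class statement needs a real class as its base. A Union alias is not one, so Python rejects it when util.py is imported, before any training code runs. This is consistent with reports in the thread that pinning an older optax makes the error go away.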

jagruti-samyak commented 2 years ago

I'm getting the same issue and haven't been able to solve it. Please help. Thank you.

mosmos6 commented 2 years ago

Maybe this works.

pip install dm-haiku==0.0.5 and put optax back to the default version.

shrey10926 commented 2 years ago

@mosmos6 Nope, it doesn't work.

anon-mouse-1 commented 2 years ago

https://github.com/kingoflolz/mesh-transformer-jax/issues/202#issuecomment-1050887576 Please follow this solution! It works.

jagruti-samyak commented 2 years ago

@anon-mouse-1 Which TPU v2 option should I use? There are two: the TPU VM architecture and the TPU Node architecture.

Tylersuard commented 2 years ago

After the fifth error, I just gave up on this notebook.

samyakai commented 2 years ago

@Tylersuard Yes. Also, no one is providing a solution to these errors, which is a shame, as I really want to train on a TPU rather than a GPU.

dhruv2601 commented 2 years ago

Downgrading optax got rid of this error for me:

pip install optax==0.0.9

mosmos6 commented 2 years ago

#202 (comment) Please follow this solution! It works.

In addition to this:

pip install chex==0.1.2
pip install jaxlib==0.1.74
pip install dm-haiku==0.0.5

and it worked for me.
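Since the fixes in this thread come down to version pins, it can help to confirm what is actually installed on the TPU VM before rerunning device_train.py. Below is a small sketch that compares installed versions against the pins quoted in the comments above (optax from one comment; chex, jaxlib and dm-haiku from another). They were reported separately and are not a verified working set; THREAD_PINS, installed_version and pin_mismatches are hypothetical names for this sketch. It uses importlib.metadata, which requires Python 3.8 or later.

```python
from importlib import metadata

# Pins quoted in this thread, collected from different comments.
# Not verified to work together; adjust to whatever set you settle on.
THREAD_PINS = {
    "optax": "0.0.9",
    "chex": "0.1.2",
    "jaxlib": "0.1.74",
    "dm-haiku": "0.0.5",
}


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def pin_mismatches(pins, get_version=installed_version):
    """Map each package whose installed version differs from its pin to
    (pinned, installed); installed is None when the package is missing."""
    mismatches = {}
    for name, pinned in pins.items():
        found = get_version(name)
        if found != pinned:
            mismatches[name] = (pinned, found)
    return mismatches


if __name__ == "__main__":
    for name, (pinned, found) in pin_mismatches(THREAD_PINS).items():
        print(f"{name}: pinned {pinned}, installed {found or 'not installed'}")
```

The version lookup is injectable (get_version) so the comparison logic can be checked without the packages installed; an empty result means every pin matches.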