Open samyakai opened 2 years ago
I'm getting the same issue and haven't been able to solve it. Please help us. Thank you.
Maybe this works.
pip install dm-haiku==0.0.5
and put optax back to the default version.
@mosmos6 Nope, it doesn't work.
https://github.com/kingoflolz/mesh-transformer-jax/issues/202#issuecomment-1050887576 Please follow this solution! It works.
@anon-mouse-1 Which TPU v2 setup should I use? There are two options for the TPU: the TPU VM architecture and the TPU Node architecture.
After the 5th error I just gave up on this notebook.
@Tylersuard Yes. Also, no one is providing a solution to the errors, which is a shame, as I really want to train on a TPU rather than a GPU.
Downgrading optax worked for me to get rid of this error.
pip install optax==0.0.9
#202 (comment) Please follow this solution! It works.
In addition to this,
pip install chex==0.1.2
pip install jaxlib==0.1.74
pip install dm-haiku==0.0.5
and it worked for me.
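To confirm the pins suggested in this thread actually took effect in the environment, a small check like the one below might help. It is only a sketch: the version numbers are the ones quoted in the comments above, and `installed_version` and `find_mismatches` are hypothetical helper names, not part of mesh-transformer-jax.

```python
from importlib import metadata

# Pins collected from the comments in this thread (not an official list).
PINS = {
    "optax": "0.0.9",
    "chex": "0.1.2",
    "jaxlib": "0.1.74",
    "dm-haiku": "0.0.5",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Return {package: (wanted, found)} for every pin that is not satisfied."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    # Report every package whose installed version differs from the pin.
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f"{name}: want {wanted}, found {found or 'not installed'}")
```

Running it in the same environment that launches device_train.py should print nothing if all four pins are in place.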
I am trying to fine-tune GPT-J on custom data using a TPU. When I run the "device_train.py" file with the documented command, "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/", I get this error:
Traceback (most recent call last):
  File "device_train.py", line 13, in <module>
    from mesh_transformer import util
  File "/home/shreyjain/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
  File "/usr/lib/python3.8/typing.py", line 317, in __new__
    raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>
OS: Ubuntu 20.04
TPU: v3-8
Python version: 3.8 and 3.7 (both give the error)
I have no idea what this error means. Any help would be appreciated! Thank you.
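For anyone wondering what the message means: the traceback fails on the line `class ClipByGlobalNormState(OptState):`, which suggests that, with the installed optax, `OptState` is a typing construct rather than an ordinary class. That is an inference from the traceback, not something confirmed in this thread. The sketch below reproduces the same kind of failure without optax. `UnionLikeState` and `PlainState` are hypothetical stand-ins, not optax's real definitions.

```python
from typing import Union

# Hypothetical stand-ins; these are NOT the real optax definitions.
UnionLikeState = Union[int, str]  # a typing construct, not a real class


class PlainState:
    """An ordinary class, which can be subclassed normally."""


def try_subclass(base):
    """Try to subclass `base`; return the TypeError message, or None on success."""
    try:
        class ClipByGlobalNormState(base):
            pass
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    # Subclassing a typing construct fails; the exact wording depends on the Python version.
    print("typing construct:", try_subclass(UnionLikeState))
    # Subclassing an ordinary class succeeds, so no message is returned.
    print("ordinary class:  ", try_subclass(PlainState))
```

On Python 3.8 the first line should show the same "Cannot subclass" message as the traceback; newer Python versions word it differently but still raise a TypeError. If this reading is right, it would also explain why pinning an older optax makes the import succeed.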