kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

TypeError: __init__() takes 2 positional arguments but 4 were given #225

Open: ghost opened this issue 2 years ago

ghost commented 2 years ago

Traceback (most recent call last):
  File "/Users/macos/Desktop/AI/serve.py", line 9, in <module>
    from mesh_transformer.checkpoint import read_ckpt
  File "/Users/macos/opt/anaconda3/lib/python3.9/site-packages/mesh_transformer/checkpoint.py", line 14, in <module>
    from mesh_transformer.util import head_print
  File "/Users/macos/opt/anaconda3/lib/python3.9/site-packages/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
TypeError: __init__() takes 2 positional arguments but 4 were given
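The traceback ends on a `class` statement, so the `TypeError` is raised while Python is *creating* the class `ClipByGlobalNormState`, not while running any function. When a base in a `class` statement is an object that is not a class (and defines no `__mro_entries__`), Python uses that object's type as the metaclass and calls it with `(name, bases, namespace)`. Counting `self`, that is 4 positional arguments, which matches the message. The sketch below reproduces the error with a hypothetical `Alias` stand-in. It illustrates the general mechanism only and is not the actual definition of `OptState` in any dependency of mesh-transformer-jax.

```python
class Alias:
    """Hypothetical stand-in for a dependency object that is *not* a class."""

    def __init__(self, origin):  # self + 1 argument = 2 positional arguments
        self.origin = origin


# An instance, not a class. This mimics a base name that a newer
# library version might expose as something other than a class.
OptState = Alias(tuple)


def try_subclass():
    """Return the TypeError message raised by the class statement, or None."""
    try:
        # Python picks type(OptState), i.e. Alias, as the metaclass and calls
        # Alias("ClipByGlobalNormState", (OptState,), namespace), so
        # Alias.__init__ receives self + 3 arguments = 4.
        class ClipByGlobalNormState(OptState):
            pass
    except TypeError as exc:
        return str(exc)
    return None


print(try_subclass())
```

On Python 3.9 the printed message matches the one in the traceback; Python 3.10 and later prefix it with the qualified name (`Alias.__init__() ...`). Because the failing line itself never changed, a mismatch between the code and the installed library versions is the likely cause, which is what the reply below points to.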

AidanShipperley commented 1 year ago

Your version of JAX is too new. Try using the dependency versions listed in https://github.com/kingoflolz/mesh-transformer-jax/issues/246.
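To check whether an environment actually matches the versions recommended in issue #246, a small script using the standard library's `importlib.metadata` (Python 3.8+) can help. This is a sketch: `check_pins` is a hypothetical helper, and the version strings are deliberately left as placeholders to be copied from #246, since they are not stated here.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pins(pins):
    """Compare installed distributions against exact pinned versions.

    `pins` maps a distribution name (e.g. "jax") to the version string you
    expect. Returns {name: (installed, wanted)} for every mismatch; an
    `installed` value of None means the distribution is not installed.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (installed, wanted)
    return mismatches


if __name__ == "__main__":
    # Hypothetical placeholder: replace with the version listed in issue #246,
    # and add an entry for every other package pinned there.
    pins = {"jax": "VERSION_FROM_ISSUE_246"}
    for name, (installed, wanted) in check_pins(pins).items():
        print(f"{name}: installed {installed}, expected {wanted}")
```

Exact string comparison is intentional: the goal of pinning is to reproduce a known-good environment, so any difference, newer or older, is reported.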