kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Error! #205

Closed. DimIsaev closed this issue 2 years ago.

DimIsaev commented 2 years ago

python3 device_serve.py

Traceback (most recent call last):
  File "device_serve.py", line 11, in <module>
    from mesh_transformer import util
  File "/home/dim/mesh-transformer-jax/mesh_transformer/util.py", line 36, in <module>
    class ClipByGlobalNormState(OptState):
  File "/usr/lib/python3.8/typing.py", line 317, in __new__
    raise TypeError(f"Cannot subclass {cls!r}")
TypeError: Cannot subclass <class 'typing._SpecialForm'>
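A minimal sketch of what this error means. It assumes something the thread does not state: the installed optax exposes `OptState` as a `typing.Union` alias, while `mesh_transformer/util.py` expects a real class it can subclass, such as `NamedTuple`. Python cannot subclass a `Union`, and on Python 3.8 the attempt raises the same `TypeError` as the traceback above. `UnionOptState`, `NamedTupleOptState` and `try_subclass` are hypothetical names used only for this illustration.

```python
from typing import NamedTuple, Union

# Hypothetical stand-ins (assumption, not optax's actual definitions):
# a Union-typed alias versus a NamedTuple-based alias for OptState.
UnionOptState = Union[int, tuple]
NamedTupleOptState = NamedTuple


def try_subclass(base):
    """Return None if `class ClipByGlobalNormState(base)` succeeds,
    otherwise return the TypeError message."""
    try:
        class ClipByGlobalNormState(base):
            """Stateless transformation state, as in util.py."""
    except TypeError as exc:
        return str(exc)
    return None


print(try_subclass(NamedTupleOptState))  # subclassing NamedTuple works
print(try_subclass(UnionOptState))       # subclassing a Union fails
```

Other Python versions also raise a `TypeError` for the `Union` case, but the message wording differs. If this assumption holds, the error points to a mismatch between the installed optax version and the one the repository was written against.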

samyakai commented 2 years ago

@DimIsaev How did you solve the error? I am getting the same error while fine-tuning.