kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Having an issue running the model #192

Closed whoislimshady closed 2 years ago

whoislimshady commented 2 years ago

Screenshot from 2022-02-08 16-38-49

I am trying to use the model for inference, but I get this error when running the script. The Python script is the same as the Colab demo code. Can someone help me with this?

Edit: I am using jax[tpu]>=0.2.16 because when I tried 0.2.12, as given in the Colab, I got an error saying: Screenshot from 2022-02-08 21-37-30
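
Since the behaviour differs between JAX versions, it may help to attach the exact versions installed in the environment. The following is a minimal sketch using only the standard-library `importlib.metadata`. The helper names `installed_version` and `version_report` are hypothetical and are not part of mesh-transformer-jax or JAX.

```python
from importlib import metadata


def installed_version(package: str) -> str:
    """Return the installed version of `package`, or 'not installed'."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return "not installed"


def version_report(packages=("jax", "jaxlib")) -> dict:
    """Map each package name to its installed version string."""
    return {name: installed_version(name) for name in packages}


if __name__ == "__main__":
    # Print one line per package so the output can be pasted into the issue.
    for name, version in version_report().items():
        print(f"{name}: {version}")
```

Running this in the same environment as the inference script shows which jax and jaxlib versions were actually picked up, which may differ from what was requested at install time.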