kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Colab version now breaks on "import optax" #162

Closed. mesotron closed this issue 2 years ago.

mesotron commented 2 years ago

The Colab notebook used to work fine, but it now breaks on "import optax" with the error "No module named 'optax'". (Yes, this is the cell that says "Sometimes the next step errors for some reason, just run it again ¯\_(ツ)_/¯", but repeated attempts don't work. It used to work every time.)

If I keep running !pip install for every library it can't find, it still fails when it reaches "from mesh_transformer.checkpoint import read_ckpt_lowmem", this time with the error "Cannot subclass <class 'typing._SpecialForm'>".
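The install-until-it-imports loop described above can be shortened by listing every missing module up front. Below is a minimal sketch using the standard library's importlib.util.find_spec. The module list is a hypothetical example and is not taken from the notebook. Keep in mind that pip package names can differ from import names (for example, the dm-haiku package provides the "haiku" module). Because find_spec only locates a module and does not run it, this check cannot detect errors raised while a module executes, such as the "Cannot subclass" error.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located.

    find_spec only searches for a module; it does not execute it, so errors
    raised at import time (such as a TypeError from a broken dependency)
    are not detected here.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


# Hypothetical list of the notebook's top-level imports; adjust to match the cell.
required = ["optax", "jax", "haiku", "transformers"]
print("Missing:", missing_modules(required))
```

Each name in the printed list can then be installed with a single !pip install cell before you rerun the import cell.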

mesotron commented 2 years ago

Realized this was a duplicate of the previous issue, now closed.