kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Error: AssertionError: Incompatible checkpoints (8,) vs (8, 4096) #198

Closed. ljj430 closed this issue 2 years ago.

ljj430 commented 2 years ago

When I ran network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1]) in the Colab demo, I got the error 'AssertionError: Incompatible checkpoints (8,) vs (8, 4096)'. It seems that devices.shape[1] is (8,).
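A message of the form "Incompatible checkpoints A vs B" typically comes from a loader comparing each array read from disk with the matching leaf of the model state. The sketch below illustrates that check. It is not the repository's read_ckpt_lowmem code: the helpers flatten_leaves and check_compatible are hypothetical, and nested dicts/lists of NumPy arrays stand in for the real parameter tree. It shows one way such a mismatch can arise. If the checkpoint on disk has a different tree layout from network.state, leaves get paired with the wrong counterparts, and an (8,)-shaped array ends up compared with an (8, 4096)-shaped one.

```python
import numpy as np

def flatten_leaves(tree):
    """Return the leaves of a nested dict/list structure in a deterministic order."""
    if isinstance(tree, dict):
        leaves = []
        for key in sorted(tree):  # sort keys so both trees flatten in the same order
            leaves.extend(flatten_leaves(tree[key]))
        return leaves
    if isinstance(tree, (list, tuple)):
        leaves = []
        for item in tree:
            leaves.extend(flatten_leaves(item))
        return leaves
    return [tree]  # anything else (e.g. an np.ndarray) is a leaf

def check_compatible(ckpt_tree, state_tree):
    """Hypothetical check: every checkpoint leaf must match the shape of its state leaf."""
    ckpt_leaves = flatten_leaves(ckpt_tree)
    state_leaves = flatten_leaves(state_tree)
    assert len(ckpt_leaves) == len(state_leaves), (
        f"Leaf count mismatch: {len(ckpt_leaves)} vs {len(state_leaves)}")
    for old, new in zip(ckpt_leaves, state_leaves):
        assert old.shape == new.shape, f"Incompatible checkpoints {old.shape} vs {new.shape}"

# Model state with two (8, 4096) leaves (assumed: 8 shards x 4096 features).
state = [np.zeros((8, 4096)), np.zeros((8, 4096))]

# Same layout: the check passes silently.
check_compatible([np.ones((8, 4096)), np.ones((8, 4096))], state)

# Different layout: the first leaf on disk has shape (8,), so the pairing is off.
try:
    check_compatible([np.zeros((8,)), np.zeros((8, 4096))], state)
except AssertionError as err:
    print(err)  # Incompatible checkpoints (8,) vs (8, 4096)
```

In this sketch, the leading 8 is assumed to be the number of model-parallel shards. The check only compares shapes, so two checkpoints with the same leaf count but a different leaf order fail on the first mismatched pair.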

Can you help me fix this problem?

Thank you so much!

ljj430 commented 2 years ago

I fixed it by downloading the slim weights.