kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Can "slim_model.py" work with "d_model" as 768? #217

Open leejason opened 2 years ago

I updated "6B_roto_256.json" with the following change to try a smaller model:

"d_model": 768

Pretraining works on a single TPU v3-8, but the model produced by "slim_model.py" generates gibberish output.

Why? Does "slim_model.py" only work with "d_model": 4096? I doubt it, but after hours of tracing the source code I have found no clue.
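For context, when shrinking "d_model" the related dimension fields usually need to be scaled together; in particular the head count should divide the model width, and the sharded dimensions should remain compatible with "cores_per_replica". A minimal illustrative fragment (key names taken from the 6B config; the specific values here are assumptions, not a known-good recipe):

```json
{
  "layers": 12,
  "d_model": 768,
  "n_heads": 8,
  "n_vocab": 50400,
  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1
}
```

Here d_model / n_heads = 96 gives an integral per-head dimension, and n_heads is divisible by cores_per_replica (8 cores on a v3-8), which is worth double-checking whenever d_model is changed.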

Thanks for any pointers.