kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

smaller models #177

Closed · leejason closed this issue 2 years ago

leejason commented 2 years ago

Can I use any of the following model configs with Mesh Transformer JAX? Are they compatible?

My purpose is to experiment with smaller models before pretraining a new GPT-J-6B from scratch on a different text corpus. Thank you for your advice.

kingoflolz commented 2 years ago

No, gpt-neo models are not compatible with this codebase

leejason commented 2 years ago

> gpt-neo models are not compatible with this codebase

Thank you very much for the update. Could I make it work by modifying "6B_roto_256.json" with the following parameters (see the config sketch below)?

"layers": 12, "d_model": 768,
"n_heads": 16, "n_vocab": 50400,

The purpose is to make the model smaller for quicker experiments. In fact, I have already trained such a model from scratch with my own vocabulary (and a newly trained tokenizer). Inference on the TPU VM seems fine, but inference on GPU after converting the checkpoint with "to_hf_weights.py" produces very strange results. I also tried "slim_model.py", but it did not help.
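One way to narrow down where the GPU-side output goes wrong is to load the converted checkpoint with the `transformers` library in full precision and compare a short greedy completion against what the TPU VM produces for the same prompt. A minimal sketch, assuming `to_hf_weights.py` wrote a GPT-J-format checkpoint and that the custom tokenizer was saved alongside it; the `./converted_hf` path and the prompt are placeholders.

```python
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

# load in float32 to rule out half-precision numerical issues
model = GPTJForCausalLM.from_pretrained("./converted_hf", torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained("./converted_hf")

# 1) confirm the converted architecture matches the training config
print(model.config.n_layer, model.config.n_embd,
      model.config.n_head, model.config.vocab_size)

# 2) compare a short greedy completion with the TPU VM output for the same prompt
inputs = tokenizer("A short test prompt", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0]))
```

If the completion only degrades in half precision, the divergence is more likely numerical; if it is already garbage in float32, the weight mapping or a tokenizer/vocabulary-size mismatch is the more likely culprit.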

If possible, could you advise on how to make the codebase compatible with smaller models? Which parts of the code would I need to change?