kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

gpt-neo models are not compatible with this codebase #189

Open · leejason opened this issue 2 years ago

leejason commented 2 years ago

> No, gpt-neo models are not compatible with this codebase

Thank you very much for the update. Can I make it work by modifying "6B_roto_256.json" with the following parameters (taken from the gpt-neo codebase)?

"layers": 12, "d_model": 768, "n_heads": 16, "n_vocab": 50400,

The purpose is to make the model smaller for quicker GPT-J experiments. In fact, I've trained the model from scratch with my own vocabulary (and a trained tokenizer). Inference on the TPU VM seems fine, but inference on GPU after converting with "to_hf_weights.py" produces very strange results. I also tried "slim_model.py," but it did not help.
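To narrow down where the GPU results go wrong, this is the kind of sanity check I have been running on the converted checkpoint. It is only a sketch: the checkpoint directory and prompt are placeholders, and it assumes my custom tokenizer was saved in Hugging Face format next to the weights written by "to_hf_weights.py".

```python
# Load the converted checkpoint with stock transformers, in full precision on CPU,
# and greedily decode a short prompt to see whether the output is already broken
# before any GPU / fp16 execution is involved.
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

CKPT_DIR = "./converted-hf"      # hypothetical output directory of to_hf_weights.py
PROMPT = "Hello, my name is"     # any prompt the model saw during training

tokenizer = AutoTokenizer.from_pretrained(CKPT_DIR)
model = GPTJForCausalLM.from_pretrained(CKPT_DIR, torch_dtype=torch.float32)
model.eval()

input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(input_ids, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

If this full-precision CPU run looks reasonable but the fp16 GPU run does not, I would suspect numerical precision rather than the conversion itself; if both look wrong, the weights or the tokenizer/vocab mapping seem more likely to blame.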

If possible, could you advise on how to make the GPT-J codebase compatible with smaller models? Which parts of the code should I change? Thank you.

Originally posted by @kingoflolz in https://github.com/kingoflolz/mesh-transformer-jax/issues/177#issuecomment-1023958725