kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Is "to_hf_weights.py" specific to "6B_roto_256.json" only? #191

Open · leejason opened this issue 2 years ago

leejason commented 2 years ago

Is "to_hf_weights.py" specific to "6B_roto_256.json" only? I was trying to make this codebase work for smaller models (e.g., "layers": 12, "d_model": 768, "n_heads": 16). However, the HF model produced by "to_hf_weights.py" generates very strange results on GPU, while "device_sample.py" works fine on TPU VM.

After several hours of trying different dtype combinations (e.g., fp16, bf16, fp32) with "to_hf_weights.py", both with and without "slim_model.py", I have not been able to build an HF model that produces plausible results with the following sample code:

model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)

(from https://huggingface.co/docs/transformers/model_doc/gptj)
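
Concretely, for the smaller model I run something like the snippet below on the GPU. The checkpoint path and prompt are placeholders, and I load the GPT-2 tokenizer because that is what I trained with; I am not sure this is the intended usage:

import torch
from transformers import AutoTokenizer, GPTJForCausalLM

# "converted_small_model" is the output directory of to_hf_weights.py
# (placeholder path); the GPT-2 BPE tokenizer is what I used for training.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPTJForCausalLM.from_pretrained(
    "converted_small_model", torch_dtype=torch.float16
).to("cuda")

inputs = tokenizer("EleutherAI is", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_length=64, do_sample=True, temperature=0.8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))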

No idea what I might have done wrong. I wonder: is it possible that "to_hf_weights.py" is not compatible with smaller models? Or, could you share more details about the following steps?

Running with HuggingFace

To use the model in HuggingFace's transformer library using pytorch, you'll need to transfer the weights into a format that it recognizes. This can be done using to_hf_weights.py. It's recommended that you use slim_model.py before attempting to move the weights to a pytorch/transformer format. Use python to_hf_weights.py --help to see usage details.

https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md

For example, what were the command-line arguments used to produce the "EleutherAI/gpt-j-6B" model hosted on HF? I would like to follow the same steps that were used to prepare the "EleutherAI/gpt-j-6B" model on HF.
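
In case it helps with debugging, this is the sanity check I run on the converted output, assuming the conversion writes a config.json alongside the weights. The mapping between the JAX config fields and the GPTJConfig attributes is my own guess, and the paths are the placeholders from above:

import json
from transformers import GPTJConfig

# Training config used with device_sample.py ("small_roto.json" is the
# placeholder name from the earlier snippet).
with open("configs/small_roto.json") as f:
    jax_config = json.load(f)

# Config saved next to the converted weights ("converted_small_model" is a
# placeholder path).
hf_config = GPTJConfig.from_pretrained("converted_small_model")

# My assumed field mapping: layers/n_layer, d_model/n_embd, n_heads/n_head.
assert hf_config.n_layer == jax_config["layers"]
assert hf_config.n_embd == jax_config["d_model"]
assert hf_config.n_head == jax_config["n_heads"]
# Probably also worth checking hf_config.rotary_dim and hf_config.vocab_size
# against the corresponding values in the JAX config.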

Thank you for any advice.