huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Open AI GPT Model Implementation in Flax #22647

Open · mayankagarwals opened this issue 1 year ago

mayankagarwals commented 1 year ago

Model description

https://huggingface.co/openai-gpt currently supports TensorFlow and PyTorch but not Flax. I'd like to implement Flax support to enhance the current GPT offering from Hugging Face.

Open source status

Provide useful links for the implementation

Given that the model is already implemented in the other two frameworks, I'll infer the Flax implementation from those. Please feel free to share additional resources that would help me wrap this up better and faster.

mayankagarwals commented 1 year ago

@sanchit-gandhi

mayankagarwals commented 1 year ago

@sanchit-gandhi @sgugger Are there any reservations about this? I have gone through the GPT architecture and the Flax code for GPT-2, and I'm fairly certain this is implementable and worth adding for completeness. The OpenAI GPT model still sees almost a million downloads a month.

Please let me know. I'd like to start with a draft PR rather than just rushing in.

sanchit-gandhi commented 1 year ago

Hey @mayankagarwals! Super sorry for not getting back to you earlier here. Let me give you my two cents: the OpenAI GPT model is definitely still super popular amongst PyTorch users (as you say, ~1 mil downloads per month). What we tend to see with Flax users though is a preference for newer, larger models (e.g. OPT, Flan-T5). This is primarily because of how easy it is to run super large models in JAX with data and model parallelism. So whilst I think this PR would be cool for completeness, I think porting a newer, more flashy model might get the JAX/Flax community more excited! How does this sound?

mayankagarwals commented 1 year ago

No worries :) @sanchit-gandhi Yes, I had held off for exactly that reason. Would you mind pointing me to a model that, in your opinion, is worth digging into and would benefit Hugging Face and the community? I have a good grasp of text-generation architectures, so something along those lines would suit me best!

sanchit-gandhi commented 1 year ago

LLaMA could be cool! What I would suggest is starting from the Flax GPT-Neo model (since this is the Flax model most similar to LLaMA) and then adding in the new bits.

vvvm23 commented 1 year ago

@sanchit-gandhi I was also thinking of adding a Flax version of LLaMA (and also GPT-NeoX, maybe others) as Flax practice. I couldn't find a guide on adding a new framework to an existing model, and I asked on the Discord without much luck (but was directed to this issue).

I'm familiar with the architectures, having already ported them to other frameworks where I work.

If you could point me in the right direction, I'd be happy to port this for you! I wasn't sure whether it's as simple as adding a new modeling_flax_* file, or if there are more moving parts and best practices to be aware of.

Thanks 🤗

sanchit-gandhi commented 1 year ago

Hey @vvvm23! In this case, since we already have the PyTorch model, the best thing to do would be to add a new Flax modelling file (modeling_flax_llama.py) that is initially copied from the Flax GPT-Neo modelling code. You can then start making changes to the Flax code to adapt it to LLaMA. The reason we copy from Flax GPT-Neo is that it contains optimised code for the attention layer, which we should try to re-use for Flax LLaMA.
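
For illustration only, here is a rough sketch of what one adapted sub-module in modeling_flax_llama.py might end up looking like; the class and attribute names (FlaxLlamaMLP, gate_proj, etc.) are my own assumptions for the sake of the example, not the final Transformers API:

```python
# Hedged sketch of one LLaMA sub-module in flax.linen style; names are illustrative.
import flax.linen as nn
import jax.numpy as jnp

class FlaxLlamaMLP(nn.Module):
    hidden_size: int
    intermediate_size: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # LLaMA uses a gated SiLU MLP, unlike GPT-Neo's plain GELU MLP,
        # so this is one of the "new bits" you would swap in after copying.
        self.gate_proj = nn.Dense(self.intermediate_size, use_bias=False, dtype=self.dtype)
        self.up_proj = nn.Dense(self.intermediate_size, use_bias=False, dtype=self.dtype)
        self.down_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)

    def __call__(self, hidden_states):
        gated = nn.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
        return self.down_proj(gated)
```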

You'll then need to make sure that the weight names match and that you have equivalence between PyTorch LLaMA and Flax LLaMA. To do this, I would recommend creating a 'dummy' version of the PyTorch LLaMA model:

from transformers import LlamaConfig, LlamaForCausalLM

# a tiny randomly initialised config so the dummy checkpoint stays small and quick to test
config = LlamaConfig(hidden_size=16, intermediate_size=24, max_position_embeddings=128, num_attention_heads=2, num_hidden_layers=2)

model = LlamaForCausalLM(config)
model.save_pretrained("./path/to/save")  # placeholder path

Then, for your test script, load this same model in PyTorch and in Flax (pass from_pt=True in the from_pretrained call), and verify with random inputs that you get the same logits out of a forward pass (example here: https://github.com/huggingface/transformers/issues/15476#issue-1121800731).
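
As a concrete sketch of that test script, assuming the new Flax class ends up exposed as FlaxLlamaForCausalLM (an assumption, since that is exactly what this port would add):

```python
# Hedged equivalence-check sketch: PyTorch vs. Flax logits on the dummy checkpoint.
import numpy as np
import torch
from transformers import LlamaForCausalLM, FlaxLlamaForCausalLM  # Flax class assumed to exist after the port

path = "./path/to/save"  # placeholder path from the dummy-model snippet above

pt_model = LlamaForCausalLM.from_pretrained(path)
flax_model = FlaxLlamaForCausalLM.from_pretrained(path, from_pt=True)  # converts the PT weights on load

# shared random token ids as a dummy input
input_ids = np.random.randint(0, pt_model.config.vocab_size, size=(1, 8))

with torch.no_grad():
    pt_logits = pt_model(torch.tensor(input_ids)).logits.numpy()

flax_logits = flax_model(input_ids).logits

# the two implementations should agree to within numerical tolerance
assert np.allclose(pt_logits, np.asarray(flax_logits), atol=1e-4)
```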

You can then focus on the tests and converting the actual model weights as required. Feel free to open a PR and tag me - more than happy to help with the integration here!
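
Converting the actual weights is then (roughly) the same from_pt=True round trip, sketched below; the checkpoint path is a placeholder, not a real repo id:

```python
# Hedged sketch of converting a real PyTorch LLaMA checkpoint to Flax weights.
from transformers import FlaxLlamaForCausalLM  # assumes the new Flax class exists

flax_model = FlaxLlamaForCausalLM.from_pretrained("path/to/pytorch-llama-checkpoint", from_pt=True)
flax_model.save_pretrained("./llama-flax")  # writes the converted Flax weights
```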

vvvm23 commented 1 year ago

Thanks @sanchit-gandhi, that was very comprehensive! I'll let you know how I get on. 🤗

vvvm23 commented 1 year ago

Got a bit caught up with real-life stuff, but I will be working on this more intensively from Monday, aiming to finish something by the end of the week.

vvvm23 commented 1 year ago

@sanchit-gandhi I made a draft PR with my current progress; see #24587. Sorry, I haven't finished the full model yet, I've been very busy 😓