mayankagarwals opened 1 year ago
@sanchit-gandhi
@sanchit-gandhi @sgugger Are there any reservations around this? I have gone through the GPT architecture and the Flax code for GPT-2, and I'm fairly certain this is implementable; it would also round out the library for completeness. The OpenAI GPT model still sees almost a million downloads a month.
Please let me know. I'd like to start with a draft PR rather than just rushing in.
Hey @mayankagarwals! Super sorry for not getting back to you earlier here. Let me give you my two cents: the OpenAI GPT model is definitely still super popular amongst PyTorch users (as you say, ~1 mil downloads per month). What we tend to see with Flax users though is a preference for newer, larger models (e.g. OPT, Flan-T5). This is primarily because of how easy it is to run super large models in JAX with data and model parallelism. So whilst I think this PR would be cool for completeness, I think porting a newer, more flashy model might get the JAX/Flax community more excited! How does this sound?
No worries :) @sanchit-gandhi Yes, I had held off for the same reason. Would you mind pointing me to a model that, in your opinion, is worth digging into and would benefit Hugging Face and the community? I have a good hold on text-generation architectures, so something aligned there would be best!
LLaMA could be cool! What I would suggest is starting from the Flax GPT-Neo model (since this is the Flax model most similar to LLaMA) and then adding the new bits in.
@sanchit-gandhi I was also thinking of adding a Flax version of LLaMA (and also GPT-NeoX, maybe others) as some Flax practice. I couldn't find a guide on adding a new framework to an existing model, and I asked on the Discord without much luck (but was directed to this issue).
I'm familiar with the architectures, having already ported them to other frameworks where I work.
If you could point me in the right direction, I would be happy to port this for you! I wasn't sure if it is as simple as adding a new `modeling_flax_*` file or if there are more parts / some best practices to be aware of.
Thanks 🤗
Hey @vvvm23! In this case, since we already have the PyTorch model, the best thing to do would be to add a new modelling file for Flax (`modeling_flax_llama.py`), initially copied from the Flax GPT-Neo modelling code. You can then start making changes to the Flax code to adapt it to LLaMA. The reason we copy from Flax GPT-Neo is that it contains optimised code for the attention layer, which we should try to re-use for Flax LLaMA.
You'll then need to make sure that the weight names match and that you have equivalence between PyTorch LLaMA and Flax LLaMA. To do this, I would recommend creating a 'dummy' version of the PyTorch LLaMA model:
```python
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny randomly-initialised model: cheap to create, save, and compare
config = LlamaConfig(
    hidden_size=16,
    intermediate_size=24,
    max_position_embeddings=128,
    num_attention_heads=2,
    num_hidden_layers=2,
)
model = LlamaForCausalLM(config)
model.save_pretrained("./path/to/save")
```
And then, for your test script, load this same model in PyTorch and then in Flax (pass `from_pt=True` in the `from_pretrained` call), and verify with random inputs that you get the same logits out when you do a forward pass (example here: https://github.com/huggingface/transformers/issues/15476#issue-1121800731).
You can then focus on the tests and on converting the actual model weights as required. Feel free to open a PR and tag me; I'm more than happy to help with the integration here!
Thanks @sanchit-gandhi that was very comprehensive! I'll let you know how I get on. :hugs:
Got a bit caught up with real life stuff, but I will be working on this more intensively from Monday, aiming to finish something by end of week.
@sanchit-gandhi I made a draft PR of my current progress, see #24587. Sorry, I haven't made the full model, been very busy 😓
Model description
https://huggingface.co/openai-gpt today supports TensorFlow and PyTorch but not Flax. I'd like to implement Flax support to enhance the current GPT offering from Hugging Face.
Open source status
Provide useful links for the implementation
Given that the model is already implemented in the other two frameworks, I'll infer the model from there. Please feel free to share additional resources that can help me wrap this up better and faster.