ayaka14732 / llama-2-jax

JAX implementation of the Llama 2 model
https://arxiv.org/abs/2307.09288
Creative Commons Zero v1.0 Universal

HF LLaMA Flax #9

Open · sanchit-gandhi opened this issue 1 year ago

sanchit-gandhi commented 1 year ago

Hey @ayaka14732! Super cool repo, thanks for working on this! With @vvvm23, we're working on adding the Flax LLaMA model to HF Transformers: https://github.com/huggingface/transformers/pull/24587 Just thought I'd let you know, since it might be of interest to you and could be a model class you can leverage in this repo for fast inference/training. By the time of release, it'll be a fully integrated version of the Flax LLaMA model in Transformers.

ayaka14732 commented 1 year ago

Hi @sanchit-gandhi, thank you for letting me know! I am definitely interested and I will be sure to check out the progress!