pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

support for gemma-2 #1813

Open almugabo opened 1 week ago

almugabo commented 1 week ago

Could you please add fine-tuning support for gemma-2? It has good multilingual capabilities and is a good candidate for fine-tuning for languages other than English. Its different sizes also make it attractive for fine-tuning for different tasks. I would gladly help but am not knowledgeable enough. Thank you!

krammnic commented 1 week ago

Actually, inspired by a current Kaggle competition, it would be a really good idea to add this pretty soon.

ebsmothers commented 1 week ago

Thanks @almugabo for creating the issue. I think this will be a bit of effort; quickly jotting down a couple of things I'm aware of that we'd need to support:

For logit softcapping and sliding window attention, I suspect we can use FlexAttention APIs. See this blog post where they give explicit examples of each.
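For reference, here is a minimal sketch of how FlexAttention's `score_mod` / `mask_mod` hooks could express both features, along the lines of the examples in that blog post. The `SOFTCAP` and `WINDOW_SIZE` values below are placeholders for illustration, not the actual Gemma-2 config, and this is not the torchtune implementation:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

SOFTCAP = 50.0      # placeholder logit softcap value, not the real Gemma-2 setting
WINDOW_SIZE = 4096  # placeholder sliding window size, not the real Gemma-2 setting


def softcap_score_mod(score, b, h, q_idx, kv_idx):
    # Tanh softcapping applied to each attention logit before softmax
    return SOFTCAP * torch.tanh(score / SOFTCAP)


def sliding_window_causal(b, h, q_idx, kv_idx):
    # Causal mask restricted to a local window of WINDOW_SIZE tokens
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW_SIZE)


B, H, S, D = 1, 8, 2048, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

# Block mask encodes the sliding-window causal structure; score_mod adds softcapping
block_mask = create_block_mask(sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, score_mod=softcap_score_mod, block_mask=block_mask)
```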

Optimox commented 1 week ago

Hello, I have started adding gemma2 and will open my PR in WIP mode now. I haven't run any tests yet but will do so soon!

Edit: My PR is here: #1835

@ebsmothers it would be great if you could take a quick look to validate the choices I made to implement sliding windows, pre/post layer normalisation, and softcapping. I'm happy to do things differently; I tried as much as possible to keep all changes minimal.
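For anyone following along, a rough sketch of what the pre/post layer normalisation pattern looks like in a Gemma-2-style transformer block (RMSNorm both before and after each sublayer). `attn` and `mlp` are placeholders for the actual attention and feed-forward modules, and this is only an illustration of the structure, not the code in the PR; `nn.RMSNorm` requires a recent PyTorch version:

```python
import torch
from torch import nn


class Gemma2StyleBlock(nn.Module):
    """Illustrative block with pre- and post-sublayer RMSNorm."""

    def __init__(self, dim: int, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.attn = attn
        self.mlp = mlp
        # Normalisation applied both before and after each sublayer
        self.pre_attn_norm = nn.RMSNorm(dim)
        self.post_attn_norm = nn.RMSNorm(dim)
        self.pre_mlp_norm = nn.RMSNorm(dim)
        self.post_mlp_norm = nn.RMSNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each residual branch: normalise, run sublayer, normalise again, add back
        x = x + self.post_attn_norm(self.attn(self.pre_attn_norm(x)))
        x = x + self.post_mlp_norm(self.mlp(self.pre_mlp_norm(x)))
        return x
```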

Optimox commented 1 week ago

I didn't know about FlexAttention; I will look into it!