google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Gemma 2 #95

Closed: talumbau closed this pull request 3 weeks ago

talumbau commented 1 month ago

Support Gemma 2

 -  add support for local sliding window attention
 -  add config for pre- and post-feed-forward layer norms
 -  add config for final logit softcap
 -  SDPA with logit softcap (HLFB not supported yet)
 -  Gemma2Block for the immediate post-attention norm
 -  ModelConfig for Gemma 2
 -  Matches goldens from the Gemma PyTorch implementation.
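Two of the attention-side changes in this list, logit soft-capping and local sliding-window attention, can be illustrated with a small framework-free Python sketch. It assumes the formulation commonly used for Gemma 2: `cap * tanh(x / cap)` applied to the raw attention scores before masking and softmax, and a causal mask restricted to the most recent `window_size` positions. The function names are hypothetical and are not part of the ai-edge-torch API.

```python
import math
from typing import List


def softcap(logits: List[float], cap: float) -> List[float]:
    """Smoothly bound each logit to at most `cap` in magnitude.

    Computes cap * tanh(x / cap): roughly x when |x| is much smaller than cap,
    saturating toward +/-cap for large |x| (unlike hard clipping, it stays smooth).
    """
    if cap <= 0:
        raise ValueError("cap must be positive")
    return [cap * math.tanh(x / cap) for x in logits]


def sliding_window_causal_mask(seq_len: int, window_size: int) -> List[List[bool]]:
    """Return mask[i][j] == True where query i may attend to key j.

    Causal (j <= i) and local (i - j < window_size), so each query sees at most
    window_size positions, including itself.
    """
    if window_size < 1:
        raise ValueError("window_size must be >= 1")
    return [[0 <= i - j < window_size for j in range(seq_len)] for i in range(seq_len)]


def masked_softmax_row(scores: List[float], allowed: List[bool]) -> List[float]:
    """Softmax over allowed positions; masked positions get weight 0.

    Equivalent to adding -inf to masked scores before a normal softmax.
    At least one position must be allowed.
    """
    m = max(s for s, ok in zip(scores, allowed) if ok)  # subtract max for stability
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    mask = sliding_window_causal_mask(seq_len=5, window_size=3)
    for row in mask:
        print("".join("x" if ok else "." for ok in row))
    # Expected mask (x = may attend):
    # x....
    # xx...
    # xxx..
    # .xxx.
    # ..xxx

    # Soft-cap raw scores for the last query, then apply the masked softmax.
    raw_scores = [12.0, -3.0, 80.0, 5.0, 1.0]
    capped = softcap(raw_scores, cap=50.0)
    print([round(w, 3) for w in masked_softmax_row(capped, mask[-1])])
```

The same `softcap` helper would serve for the final-logit softcap; only the cap value and the tensor it is applied to differ.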

a8nova commented 4 weeks ago

Hello! I tried converting with this PR but ran into issues. I understand this is a WIP, but I still wanted to try. For example, in the tensor mappings, where does the tensor name `embedder` come from? I had to modify it to get the converter to work; in the Gemma 2 source code it is called `embed_tokens`. Thanks!

Can I use this PR to convert Gemma 2 weights?
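When the converter's tensor mapping and the checkpoint disagree on names (`embedder` vs `embed_tokens` here), one workaround besides editing the mapping is to rename checkpoint keys before loading. The sketch below is plain Python over a dict. The prefix mapping in the example is a hypothetical illustration, not the actual names used by ai-edge-torch's loader or by any particular Gemma 2 checkpoint, so verify both sides before relying on it.

```python
from typing import Dict, Mapping, TypeVar

T = TypeVar("T")  # tensor type, e.g. torch.Tensor in a real checkpoint


def rename_tensors(state_dict: Mapping[str, T], prefix_map: Mapping[str, str]) -> Dict[str, T]:
    """Return a copy of state_dict with key prefixes rewritten per prefix_map.

    The longest matching prefix wins, so specific rules override general ones.
    Keys that match no prefix are kept unchanged. Raises KeyError if two keys
    would end up with the same name.
    """
    # Try longer prefixes first so e.g. "model.embed_tokens" beats "model.".
    ordered = sorted(prefix_map.items(), key=lambda kv: len(kv[0]), reverse=True)
    renamed: Dict[str, T] = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in ordered:
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        if new_key in renamed:
            raise KeyError(f"rename collision on {new_key!r}")
        renamed[new_key] = value
    return renamed


if __name__ == "__main__":
    # Hypothetical checkpoint keys; string values stand in for tensors.
    checkpoint = {"model.embed_tokens.weight": "E", "model.norm.weight": "N"}
    # Hypothetical rule: the converter expects "embedder.weight".
    fixed = rename_tensors(checkpoint, {"model.embed_tokens": "embedder"})
    print(sorted(fixed))  # ['embedder.weight', 'model.norm.weight']
```

The collision check matters because a careless prefix rule can silently overwrite another weight, which would only show up later as wrong outputs.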

haozha111 commented 3 weeks ago

can we add a test in https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/test/test_model_conversion.py ?
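For a sense of what such a conversion test checks, here is a hypothetical, self-contained `unittest` sketch of a golden comparison: it asserts that a converted model's outputs match reference outputs element-wise within absolute and relative tolerances (the same rule `numpy.allclose` uses). `reference_forward` and `converted_forward` are stand-ins invented for illustration; the real test in `test_model_conversion.py` is not reproduced here.

```python
import unittest
from typing import List, Sequence


def outputs_close(actual: Sequence[float], expected: Sequence[float],
                  atol: float = 1e-5, rtol: float = 1e-5) -> bool:
    """Element-wise |a - e| <= atol + rtol * |e|, the rule numpy.allclose uses.

    Sequences of different lengths are never considered close.
    """
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def reference_forward(tokens: List[int]) -> List[float]:
    """Hypothetical stand-in for the original PyTorch model's logits."""
    return [0.1 * t for t in tokens]


def converted_forward(tokens: List[int]) -> List[float]:
    """Hypothetical stand-in for the converted model, with tiny numeric drift."""
    return [0.1 * t + 1e-7 for t in tokens]


class GoldenComparisonTest(unittest.TestCase):
    def test_converted_matches_reference(self):
        tokens = [2, 3, 5, 7]
        self.assertTrue(outputs_close(converted_forward(tokens), reference_forward(tokens)))

    def test_detects_mismatch(self):
        self.assertFalse(outputs_close([1.0, 2.0], [1.0, 2.1]))


if __name__ == "__main__":
    # exit=False so the interpreter keeps running after the tests finish.
    unittest.main(argv=["golden_comparison_test"], exit=False)
```

Combining `atol` and `rtol` keeps the check meaningful for both near-zero and large logits.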

talumbau commented 3 weeks ago

> can we add a test in https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/test/test_model_conversion.py ?

See the new test in test_model_conversion.py. Two notes: