google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0
278 stars, 36 forks

Allow explicitly setting head_dim instead of deducing it. #87

Closed: mbrenon closed this pull request 1 month ago

mbrenon commented 1 month ago

Being able to set head_dim explicitly is a requirement for the upcoming OpenELM models.

Note that head_dim also moves from the main model config down into the attention config: it is an attention parameter, so it has no reason to live in the top-level model config.
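To make the motivation concrete, here is a minimal sketch, assuming hypothetical `AttentionConfig` and `ModelConfig` dataclasses (not the repo's actual classes) and illustrative sizes. It contrasts the deduced value (`embedding_dim // num_heads`, assumed to be the old behavior) with an explicit `head_dim`, and shows that the query projection width `num_heads * head_dim` no longer has to equal the embedding size.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionConfig:
    # Hypothetical sketch: head_dim is an attention parameter, so it is
    # stored here and set explicitly instead of being deduced.
    num_heads: int
    head_dim: int


@dataclass(frozen=True)
class ModelConfig:
    # Hypothetical top-level config: it no longer carries head_dim.
    embedding_dim: int
    attn_config: AttentionConfig


def deduced_head_dim(config: ModelConfig) -> int:
    # Assumed old behavior: derive head_dim from the embedding size.
    # This silently truncates when embedding_dim is not divisible by num_heads.
    return config.embedding_dim // config.attn_config.num_heads


def query_projection_shape(config: ModelConfig) -> tuple[int, int]:
    # (in_features, out_features) of the query projection. With an explicit
    # head_dim, out_features need not equal embedding_dim.
    attn = config.attn_config
    return (config.embedding_dim, attn.num_heads * attn.head_dim)


# Illustrative sizes only: 1280-dim embeddings, 12 heads of width 64.
example = ModelConfig(
    embedding_dim=1280,
    attn_config=AttentionConfig(num_heads=12, head_dim=64),
)
print(query_projection_shape(example))  # (1280, 768)
print(deduced_head_dim(example))  # 106, not the intended 64
```

Keeping `head_dim` on the attention config means each attention block can describe its own geometry, which is what a model whose head width is decoupled from its embedding size needs.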

Please review carefully to make sure I'm not breaking anything here :)

BUG=https://b.corp.google.com/issues/352478939

talumbau commented 1 month ago

I ran inference on the example decoder models with this branch and everything looked good. I also converted tiny_llama with no issues. It would be great to get an official thumbs up from someone who has worked with the Stable Diffusion model.

haozha111 commented 1 month ago

@yichunk could you help review the changes to the SD (Stable Diffusion) code? Thanks!

talumbau commented 1 month ago

Actually, I see we need one more change for the T5 encoder/decoder models, similar to the one made for the decoder-only models. Please add

 head_dim=64

to the AttentionConfig in the get_model_config_t5() function in t5.py. Thanks @mbrenon!
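A minimal sketch of the requested change, using hypothetical stand-ins for the repo's `AttentionConfig` and for what `get_model_config_t5()` returns. Only `head_dim=64` comes from the comment; the T5-base-like sizes (768-dim embeddings, 12 heads) are assumptions. With those sizes the explicit value matches the previously deduced one, so the change should leave the model's shapes unchanged.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionConfig:
    # Hypothetical stand-in for the repo's AttentionConfig.
    num_heads: int
    head_dim: int


@dataclass(frozen=True)
class T5ConfigSketch:
    # Hypothetical stand-in for the config returned by get_model_config_t5().
    embedding_dim: int
    attn_config: AttentionConfig


def get_model_config_t5_sketch() -> T5ConfigSketch:
    # Assumed T5-base-like sizes; head_dim=64 is the value from the review.
    attn_config = AttentionConfig(num_heads=12, head_dim=64)
    return T5ConfigSketch(embedding_dim=768, attn_config=attn_config)


config = get_model_config_t5_sketch()
# Sanity check: explicit head_dim equals the old deduced value (768 // 12).
print(config.attn_config.head_dim, config.embedding_dim // config.attn_config.num_heads)
```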

haozha111 commented 1 month ago

@yichunk could you help approve this PR? Thanks!

talumbau commented 1 month ago

Closing in favor of #120.