google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Remove 5D tensor reshape in attention layer implementation. #57

Closed (yichunk closed this 3 weeks ago)

yichunk commented 3 weeks ago

Remove 5D tensor reshape in attention layer implementation.

Tested on TinyLlama; the results match those from before the reshape was removed.

BUG=b/347108294

haozha111 commented 3 weeks ago

Great, have you tested whether it works for the other models (phi2 and gemma)? We need to check that this change is correct for all attention types (MHA, MQA, GQA). Thanks!

yichunk commented 3 weeks ago

> Great, have you tested whether it works for the other models (phi2 and gemma)? We need to check that this change is correct for all attention types (MHA, MQA, GQA). Thanks!

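TinyLlama uses GQA, which generalizes both MHA (one KV head per query head) and MQA (a single KV head shared by all query heads), so it should cover those cases as well. I'll test on phi2 and gemma to confirm.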
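
For readers following the coverage argument, here is a minimal pure-Python sketch of why a grouped layout and a flat layout should agree for MHA, GQA, and MQA. It is not the ai-edge-torch implementation, and the PR diff is not reproduced here. It assumes the removed 5D reshape grouped query heads by the KV head they share, and that the flat version repeats each KV head instead. All function names (`attend`, `gqa_grouped`, `gqa_flat`, `check_all`) are hypothetical, the batch axis is omitted, and attention is non-causal for brevity.

```python
import math
import random


def attend(q, k, v):
    """Non-causal scaled dot-product attention for one head.

    q, k, v are lists of T vectors, each of length D (plain Python lists).
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_t in q:
        scores = [scale * sum(a * b for a, b in zip(q_t, k_t)) for k_t in k]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_t[d] for w, v_t in zip(weights, v)) / total
            for d in range(len(v[0]))
        ])
    return out


def gqa_grouped(q, k, v):
    """Grouped layout: query heads are viewed as [G][H // G] (the extra axis)."""
    per_group = len(q) // len(k)
    out = []
    for g in range(len(k)):
        # The H // G query heads that share KV head g.
        group = q[g * per_group:(g + 1) * per_group]
        out.extend(attend(q_h, k[g], v[g]) for q_h in group)
    return out


def gqa_flat(q, k, v):
    """Flat layout: each KV head is repeated H // G times, no extra axis."""
    per_group = len(q) // len(k)
    k_rep = [k[h // per_group] for h in range(len(q))]
    v_rep = [v[h // per_group] for h in range(len(q))]
    return [attend(q_h, k_h, v_h) for q_h, k_h, v_h in zip(q, k_rep, v_rep)]


def random_heads(rng, heads, seq_len, dim):
    """Build a [heads][seq_len][dim] nested list of random floats."""
    return [[[rng.uniform(-1.0, 1.0) for _ in range(dim)]
             for _ in range(seq_len)] for _ in range(heads)]


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two nested lists."""
    return max(abs(x - y)
               for head_a, head_b in zip(a, b)
               for row_a, row_b in zip(head_a, head_b)
               for x, y in zip(row_a, row_b))


def check_all(num_q_heads=4, seq_len=3, dim=2, seed=0):
    """Compare both layouts for MHA (G == H), GQA (1 < G < H), and MQA (G == 1)."""
    rng = random.Random(seed)
    results = {}
    for name, num_kv in (("MHA", num_q_heads), ("GQA", 2), ("MQA", 1)):
        q = random_heads(rng, num_q_heads, seq_len, dim)
        k = random_heads(rng, num_kv, seq_len, dim)
        v = random_heads(rng, num_kv, seq_len, dim)
        results[name] = max_abs_diff(gqa_grouped(q, k, v), gqa_flat(q, k, v))
    return results
```

Calling `check_all()` returns the largest difference between the two layouts for each attention type. A check like this only shows that the two index mappings agree; it does not replace running the converted models, which is why testing phi2 and gemma is still worthwhile.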

yichunk commented 3 weeks ago

Tested on phi2 and Gemma; the results match those from before the change.