google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Call CrossAttention __init__ with batch dim #83

Closed: talumbau closed this 1 month ago

talumbau commented 2 months ago

Fixes T5 model inference by calling CrossAttention __init__ with the batch dimension.

BUG=https://buganizer.corp.google.com/issues/352073514