google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Fix SDPA FX pass for mask input #85

Status: Closed (Linchenn closed this 3 months ago)

Linchenn commented 3 months ago

BUG=NONE

haozha111 commented 3 months ago

let's add a unit test?

chunnienc commented 3 months ago

> let's add a unit test?

The old implementation does not hit any failing corner cases with the current setup. This change just makes the pass more robust to future changes, e.g. if we extend the definition of a zero tensor to cover not only tensors produced by aten.zeros but also tensor inputs that are all zeros.

I talked to Lin about the lifting behavior of the composite builder: an unmarked model input would still be lifted as a composite input, which may break this pass. In general, always replacing an input zero tensor with a constant tensor is the more robust implementation.
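A minimal sketch of the idea in the last two comments, on a hypothetical toy graph IR. The `Node` class, the op names, `is_zero_tensor`, and `replace_zero_masks` are all invented for illustration; this is not torch.fx and not the actual ai-edge-torch pass. It treats an SDPA mask as a zero tensor if it comes from a zeros op or if its known (example) value is all zeros, and rewires the SDPA to a fresh constant so that later lifting of the original input cannot affect the mask. Folding an input into a constant is only sound under the assumption that the input is guaranteed to be zero at runtime.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    # Hypothetical toy IR node, NOT torch.fx.Node or anything in ai-edge-torch.
    name: str
    op: str  # one of "placeholder", "zeros", "constant", "sdpa"
    args: List["Node"] = field(default_factory=list)
    value: List[float] = field(default_factory=list)  # flattened (example) data


def is_zero_tensor(node: Node) -> bool:
    """Broadened zero-tensor check: output of a zeros op, or any node whose
    known (example) value is non-empty and all zeros (assumed always zero)."""
    if node.op == "zeros":
        return True
    return bool(node.value) and all(x == 0 for x in node.value)


def replace_zero_masks(graph: List[Node]) -> List[Node]:
    """Rewire each SDPA node whose mask (4th arg) is a non-constant zero tensor
    to a fresh constant node inserted just before it. SDPA nodes are mutated in
    place; the returned list is the updated node order."""
    new_graph: List[Node] = []
    for node in graph:
        if node.op == "sdpa" and len(node.args) == 4:
            mask = node.args[3]
            # Skip masks that are already constants so the pass is idempotent.
            if mask.op != "constant" and is_zero_tensor(mask):
                const = Node(f"{mask.name}_const", "constant",
                             value=[0.0] * len(mask.value))
                new_graph.append(const)
                node.args = node.args[:3] + [const]
        new_graph.append(node)
    return new_graph


if __name__ == "__main__":
    q, k, v = (Node(n, "placeholder", value=[1.0, 2.0]) for n in "qkv")
    mask = Node("mask", "placeholder", value=[0.0] * 4)  # unmarked all-zero input
    attn = Node("attn", "sdpa", args=[q, k, v, mask])
    graph = replace_zero_masks([q, k, v, mask, attn])
    print([n.name for n in graph])  # ['q', 'k', 'v', 'mask', 'mask_const', 'attn']
    print(attn.args[3].op)  # constant
```

The original `mask` placeholder stays in the node list, so the model signature is unchanged; only the SDPA's mask operand now points at a constant, which is the property the comment above relies on.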