After the conv numerics issue #498 was fixed with this fix, we see a numerical error in the output of the first SDPA op when running UNet.
The numerical error on all three inputs to the op (run from the beginning, on the same inputs) is 0.01%.
The numerical error after SDPA is ~83%.
Both figures are relative to the output of PyTorch fp16/fp32.
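For anyone trying to reproduce the comparison, here is a minimal sketch of one way to compute such a figure. It assumes the percentages above mean "share of elements outside a tolerance", which this issue does not state. The names `sdpa_reference` and `mismatch_percent` and the tolerances `rtol`/`atol` are hypothetical, and the reference is plain NumPy in float32 rather than the PyTorch run used for the numbers above.

```python
import numpy as np


def sdpa_reference(q, k, v):
    """Scaled dot-product attention, softmax(q @ k^T / sqrt(d)) @ v, in float32."""
    q, k, v = (np.asarray(x, dtype=np.float32) for x in (q, k, v))
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    # Subtract the row-wise max before exponentiating for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def mismatch_percent(actual, expected, rtol=1e-2, atol=1e-3):
    """Percentage of elements where actual is not within tolerance of expected.

    The tolerances are assumed values, not the ones used in this issue.
    """
    actual = np.asarray(actual, dtype=np.float32)
    expected = np.asarray(expected, dtype=np.float32)
    close = np.isclose(actual, expected, rtol=rtol, atol=atol)
    return 100.0 * (1.0 - close.mean())
```

Usage would be along the lines of `mismatch_percent(iree_output, sdpa_reference(q, k, v))`, where `iree_output`, `q`, `k` and `v` are arrays loaded from the attached artefacts (array names are placeholders).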
IREE-compile command
IREE-run command
Artefacts, containing the inputs and expected_output (PyTorch f16): all_artefacts.zip, sdpa_fp16.mlir