nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

SDXL UNET Numerics - SDPA Op results mismatch with pytorch results #507

Open PhaneeshB opened 8 months ago

PhaneeshB commented 8 months ago

After the conv numerics issue #498 was fixed with this fix, we see an error in the output of the first SDPA op when running the UNet.

The numerical error for all three SDPA inputs (running from the beginning of the model on the same inputs) is 0.01%, but the numerical error after SDPA is ~83% when compared with the PyTorch fp16/fp32 output.
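
The issue does not say how these percentages are computed. One common way to express a mismatch rate is the percentage of elements that fall outside an absolute/relative tolerance of the reference. The sketch below assumes that metric; the helper name `mismatch_percent` and its default tolerances are hypothetical, not part of the original report.

```python
import numpy as np


def mismatch_percent(actual, expected, atol=1e-3, rtol=1e-3):
    """Return the percentage of elements where actual and expected disagree.

    Hypothetical metric (not taken from the issue): an element counts as a
    mismatch when |actual - expected| > atol + rtol * |expected|, which is
    the test np.isclose applies. NaNs in either array count as mismatches.
    """
    # Upcast so fp16 outputs are compared in fp32 rather than fp16.
    actual = np.asarray(actual, dtype=np.float32)
    expected = np.asarray(expected, dtype=np.float32)
    if actual.shape != expected.shape:
        raise ValueError(f"shape mismatch: {actual.shape} vs {expected.shape}")
    bad = ~np.isclose(actual, expected, rtol=rtol, atol=atol)
    return 100.0 * np.count_nonzero(bad) / bad.size
```

For example, `mismatch_percent(np.array([1.0, 2.0, 3.0, 4.0]), np.array([1.0, 2.0, 3.5, 4.0]))` returns `25.0`. fp16 outputs usually need looser tolerances than these defaults.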

IREE-compile command

tools/iree-compile  --iree-hal-target-backends=rocm \
--iree-rocm-target-chip=gfx940 \
--iree-rocm-link-bc=true \
--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode \
--iree-opt-strip-assertions=true --verify=false \
--iree-vm-bytecode-module-strip-source-map=true \
--iree-vm-target-truncate-unsupported-floats \
--iree-hal-dump-executable-files-to=haldump \
--iree-flow-dump-dispatch-graph \
--iree-global-opt-propagate-transposes=true \
--iree-opt-outer-dim-concat=true \
--iree-opt-const-eval=false \
--iree-codegen-gpu-native-math-precision=true \
--iree-rocm-waves-per-eu=2 \
--iree-preprocessing-pass-pipeline="builtin.module(iree-preprocessing-transpose-convolution-pipeline)" \
--iree-codegen-transform-dialect-library=/home/pbarwari/attention_mfma_transform_64_spec.mlir sdpaonly_f16.mlir -o sdpaonly_f16.vmfb

IREE-run command

tools/iree-run-module 2_fx_importer_module_f16_hacked.vmfb --module=sdpaonly_f16.vmfb \
--device=rocm \
--function=forward \
--input=@input_transpose0_2x10x4096x64_f16.npy \
--input=@input_transpose1_2x10x4096x64_f16.npy \
--input=@input_transpose2_2x10x4096x64_f16.npy \
--output=@output_sdpa_2x10x4096x64_f16.npy
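
To check which side is off, it can help to recompute the expected output independently in fp32. Below is a minimal NumPy sketch, assuming the SDPA op follows the defaults of PyTorch's `torch.nn.functional.scaled_dot_product_attention` (no mask, no dropout, non-causal, scale `1/sqrt(head_dim)`) and that the three `input_transpose*` files are Q, K and V in that order. Neither assumption is confirmed in the issue, and `sdpa_reference` is a hypothetical helper.

```python
import math

import numpy as np


def sdpa_reference(q, k, v):
    """Scaled dot-product attention in float32: softmax(q @ k^T / sqrt(d)) @ v.

    Assumes PyTorch's SDPA defaults (hypothetical for this issue): no mask,
    no dropout, non-causal, scale = 1/sqrt(head_dim).
    Inputs have shape (batch, heads, seq_len, head_dim).
    """
    q = np.asarray(q, dtype=np.float32)
    k = np.asarray(k, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    scale = 1.0 / math.sqrt(q.shape[-1])
    out = np.empty(q.shape[:-1] + (v.shape[-1],), dtype=np.float32)
    # Loop over (batch, head) so only one seq_len x seq_len score matrix is
    # live at a time (4096 x 4096 float32 = 64 MiB for the shapes above).
    for b in range(q.shape[0]):
        for h in range(q.shape[1]):
            scores = (q[b, h] @ k[b, h].T) * scale
            # Subtract the row max before exp for a numerically stable softmax.
            scores -= scores.max(axis=-1, keepdims=True)
            weights = np.exp(scores)
            weights /= weights.sum(axis=-1, keepdims=True)
            out[b, h] = weights @ v[b, h]
    return out
```

With the files from the run command above, one would load each `.npy` with `np.load`, call `sdpa_reference(q, k, v)`, and compare the result against `output_sdpa_2x10x4096x64_f16.npy` and against the PyTorch expected output. Computing in fp32 separates fp16 accumulation or softmax precision effects from an outright miscompile.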

Artefacts (the inputs plus the expected output from PyTorch f16): all_artefacts.zip, sdpa_fp16.mlir

PhaneeshB commented 8 months ago

attention_mfma_transform_64_spec.txt