iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

(gfx1103/Windows) Numerics issues on HIP driver for SDXL Unet #17579

Open monorimet opened 5 months ago

monorimet commented 5 months ago

What happened?

The same .vmfb gives different results on the ROCM and HIP HAL drivers. The caching allocator is used on both, but disabling it does not seem to make a difference.

Good numerics: ROCM with inlined weights gives correct output.

Bad numerics # 1: ROCM with external weights gives all zeroes output

Bad numerics # 2: HIP with inlined weights gives wrong numbers.

Bad numerics # 3: HIP with external weights gives wrong numbers.

I am filing this issue specifically for this target and IR because other targets and models do not reproduce the same success/failure cases. (see https://github.com/iree-org/iree/issues/17033)

I am including the ROCM HAL results only because they contain the only success mode. We should focus on fixing the HIP HAL issues.

Full log output using turbine-models scripts -- I will provide iree CLI reproducers as well, but these are using fixed random inputs:

(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet_runner.py --precision=fp16 --device=hip --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103 --vmfb_path=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb                                                  
TURBINE OUTPUT: [[[[ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   ...
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]]

  [[-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   ...
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]]

  [[ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   ...
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]]

  [[-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   ...
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]]]] (1, 4, 128, 128) float16
(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet_runner.py --precision=fp16 --device=rocm --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103 --vmfb_path=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb
TURBINE OUTPUT: [[[[-0.6895  -0.3262   0.9443  ... -0.9014   1.218    0.2766 ]
   [-0.1412   0.2     -0.4927  ...  0.627   -1.163    0.4128 ]
   [-0.01245  0.633    0.1671  ... -0.05722 -0.0687  -0.10736]
   ...
   [-0.1569  -0.1216   0.325   ...  0.3892   0.7476   0.06064]
   [-0.4585   0.2944  -0.9595  ...  0.797    0.2452   0.1302 ]
   [-0.02382  1.318   -0.2832  ... -0.4692   1.057   -1.516  ]]

  [[ 0.03044 -0.4458   0.836   ...  0.2281   0.4502  -0.0377 ]
   [ 0.0486  -0.4158  -0.2251  ... -0.4724   0.4004   1.592  ]
   [ 0.92    -0.573    0.2286  ...  0.81     0.01987  0.398  ]
   ...
   [-0.5947  -1.238   -0.05618 ...  0.1353   0.0868  -0.2744 ]
   [-0.1533  -0.291    0.1362  ...  0.1338   0.1406   0.9385 ]
   [ 0.03732  1.064    1.513   ...  0.3914  -0.6694   0.699  ]]

  [[ 0.881   -0.3994   0.763   ...  0.339    0.7397  -0.295  ]
   [ 0.615    0.203   -0.7407  ... -0.1326  -0.0328   0.147  ]
   [ 0.733   -0.1461   0.1094  ...  0.44    -1.463    0.8037 ]
   ...
   [ 0.2075   0.4565   0.7773  ... -0.3655   0.1267  -0.02698]
   [ 0.4185   0.218   -0.297   ...  0.478   -1.067    1.498  ]
   [ 1.348   -0.2026  -0.1068  ...  0.044   -0.05292 -0.163  ]]

  [[ 0.7183   0.2141  -0.743   ...  1.86    -2.348   -0.1821 ]
   [-0.165   -0.10284 -0.02016 ... -0.3496   0.9595   1.24   ]
   [ 0.1755  -0.0325   1.589   ...  0.3892  -1.476    0.8857 ]
   ...
   [ 0.2554   0.6816   0.2898  ... -0.2029   0.3306   0.3394 ]
   [-1.031   -0.042    0.5566  ...  0.4116   1.478    0.01047]
   [ 0.575    1.234   -1.045   ... -0.8857   0.8745  -0.1274 ]]]] (1, 4, 128, 128) float16    
(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet.py --precision=fp16 --device=rocm --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103                                                   
C:\Users\eagarvey\SHARK\SHARK\shark.venv\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
C:\Users\eagarvey\SHARK\SHARK\shark.venv\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
C:\Users\eagarvey\SHARK\SHARK\shark.venv\Lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Compiling to rocm with flags: ['--iree-hal-target-backends=rocm', '--iree-rocm-target-chip=gfx1103', '--iree-vm-bytecode-module-output-format=flatbuffer-binary', '--iree-global-opt-propagate-transposes=true', '--iree-opt-outer-dim-concat=true', '--iree-vm-target-truncate-unsupported-floats', '--iree-llvmgpu-enable-prefetch=true', '--iree-opt-data-tiling=false', '--iree-opt-const-eval=false', '--iree-opt-aggressively-propagate-transposes=true', '--iree-flow-enable-aggressive-fusion', '--iree-global-opt-enable-fuse-horizontal-contractions=true', '--iree-codegen-gpu-native-math-precision=true', '--iree-codegen-llvmgpu-use-vector-distribution=true', '--iree-codegen-llvmgpu-enable-transform-dialect-jit=false', '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))', '--iree-codegen-transform-dialect-library=attention_and_matmul_spec_wmma.mlir']
Saved to stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.mlir
Saved to stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb
(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet_runner.py --precision=fp16 --device=hip --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103 --vmfb_path=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb --external_weight_path=scheduled_unet.safetensors
TURBINE OUTPUT: [[[[0.334    0.273    0.807    ... 0.2734   0.649    0.7837  ]
   [0.0825   0.8296   0.7188   ... 0.659    0.229    0.1401  ]
   [0.2563   0.9097   0.84     ... 0.166    0.2485   0.251   ]
   ...
   [0.6636   0.4502   0.718    ... 0.986    0.9766   0.5557  ]
   [0.2847   0.2104   0.1401   ... 0.68     0.5996   0.5386  ]
   [0.5967   0.753    0.3506   ... 0.3213   0.743    0.0962  ]]

  [[0.853    0.274    0.7505   ... 0.463    0.627    0.7515  ]
   [0.1572   0.626    0.6274   ... 0.845    0.9517   0.774   ]
   [0.555    0.3218   0.3975   ... 0.8486   0.3613   0.393   ]
   ...
   [0.169    0.05664  0.477    ... 0.3018   0.3994   0.5474  ]
   [0.3916   0.2373   0.4214   ... 0.5234   0.836    0.9893  ]
   [0.6196   0.802    0.8315   ... 0.708    0.3242   0.898   ]]

  [[0.513    0.002441 0.2339   ... 0.0674   0.9834   0.6255  ]
   [0.6973   0.6025   0.3115   ... 0.7783   0.9077   0.09814 ]
   [0.4058   0.8477   0.658    ... 0.2798   0.01807  0.04834 ]
   ...
   [0.6553   0.3813   0.8765   ... 0.536    0.4678   0.02344 ]
   [0.7695   0.764    0.8647   ... 0.313    0.007324 0.921   ]
   [0.8804   0.05664  0.4668   ... 0.2632   0.1309   0.1758  ]]

  [[0.5986   0.0801   0.31     ... 0.6123   0.1484   0.0947  ]
   [0.11816  0.9585   0.796    ... 0.169    0.69     0.1992  ]
   [0.1831   0.552    0.9834   ... 0.0801   0.02051  0.4287  ]
   ...
   [0.9175   0.773    0.9463   ... 0.9404   0.835    0.1401  ]
   [0.5796   0.1968   0.5195   ... 0.3672   0.9507   0.57    ]
   [0.2344   0.531    0.0547   ... 0.0703   0.9297   0.8394  ]]]] (1, 4, 128, 128) float16 
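The printed outputs above differ in recognizable ways: one is constant within each channel, one has mixed-sign values, and the printed values of the last one are all non-negative. A minimal NumPy sketch for telling such cases apart, assuming the runner outputs have been saved as arrays; the helper names `summarize_output` and `max_abs_diff` are hypothetical and not part of turbine-models or IREE:

```python
import numpy as np


def summarize_output(arr) -> dict:
    """Return simple statistics that help distinguish the failure modes."""
    # Work in float32 so fp16 rounding does not affect the comparisons.
    a = np.asarray(arr, dtype=np.float32)
    # Spread over the last two (spatial) axes; zero spread means every
    # element in that channel is identical.
    spread = a.max(axis=(-2, -1)) - a.min(axis=(-2, -1))
    return {
        "all_zero": bool(np.all(a == 0)),
        "constant_per_channel": bool(np.all(spread == 0)),
        "non_negative": bool(np.all(a >= 0)),
        "has_nan_or_inf": bool(not np.all(np.isfinite(a))),
        "min": float(a.min()),
        "max": float(a.max()),
    }


def max_abs_diff(a, b) -> float:
    """Largest element-wise absolute difference between two outputs."""
    a32 = np.asarray(a, dtype=np.float32)
    b32 = np.asarray(b, dtype=np.float32)
    return float(np.max(np.abs(a32 - b32)))


if __name__ == "__main__":
    # Synthetic stand-in for the constant-per-channel output shown above.
    per_channel = np.array([0.002733, -0.001411, 0.001999, -0.00314],
                           dtype=np.float16).reshape(1, 4, 1, 1)
    constant = np.broadcast_to(per_channel, (1, 4, 128, 128))
    print(summarize_output(constant))
```

With real data, the two arrays would come from something like `np.load("rocm_out.npy")` and `np.load("hip_out.npy")` (hypothetical file names), and `max_abs_diff` gives a single number to track while changing drivers or flags.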

Steps to reproduce your issue

Artifacts:

MLIR (FP16): https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.mlir

MLIR (FP32): https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp32_unet_cpu.mlir

WMMA spec: https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/attention_and_matmul_spec_wmma.mlir

MLIR (inlined, fp16): https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_inline.mlir

Inputs:

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input1.npy

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input2.npy

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input3.npy

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input4.npy

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input5.npy

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input6.npy

Weights: https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/scheduled_unet.irpa

Compile:

iree-compile --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx1103 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-global-opt-propagate-transposes=true --iree-opt-outer-dim-concat=true --iree-vm-target-truncate-unsupported-floats --iree-llvmgpu-enable-prefetch=true --iree-opt-data-tiling=false --iree-opt-const-eval=false --iree-opt-aggressively-propagate-transposes=true --iree-flow-enable-aggressive-fusion --iree-global-opt-enable-fuse-horizontal-contractions=true --iree-codegen-gpu-native-math-precision=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))' --iree-codegen-transform-dialect-library=attention_and_matmul_spec_wmma.mlir stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.mlir -o stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb

Run:

iree-run-module --module=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp32_unet_cpu.vmfb --device_allocator=caching --parameters=model=scheduled_unet.irpa --input=@input1.npy --input=@input2.npy --input=@input3.npy --input=@input4.npy --input=@input5.npy --input=@input6.npy --device=hip
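Since the comparison of interest is the same module run under two HAL drivers, it can help to generate both command lines from one place so that only `--device` differs. A small sketch under stated assumptions: the helper `build_run_command` is hypothetical, it only uses the `iree-run-module` flags that appear in the command above, and the module name in the usage section is taken from the Compile step rather than from the Run command:

```python
import shlex


def build_run_command(device, module, parameters=None, inputs=(),
                      allocator="caching"):
    """Build an iree-run-module argument list for one HAL driver."""
    cmd = ["iree-run-module", f"--module={module}", f"--device={device}"]
    if allocator:
        # Pass allocator=None to drop the caching allocator entirely.
        cmd.append(f"--device_allocator={allocator}")
    if parameters:
        # External weights are bound under the "model" scope, as above.
        cmd.append(f"--parameters=model={parameters}")
    cmd.extend(f"--input=@{path}" for path in inputs)
    return cmd


if __name__ == "__main__":
    inputs = [f"input{i}.npy" for i in range(1, 7)]
    # Assumption: the .vmfb produced by the Compile step is the one to run.
    module = "stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb"
    for device in ("rocm", "hip"):
        cmd = build_run_command(device, module, "scheduled_unet.irpa", inputs)
        # Print a shell-quoted line; hand the list to subprocess.run to execute.
        print(shlex.join(cmd))
```

Returning a list instead of a string keeps quoting out of the picture when the command is passed to `subprocess.run`.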

What component(s) does this issue relate to?

No response

Version information

The IREE branch used is shared/tresleches-united, but these issues have historically reproduced on the main branch as well, though not all of the compile options here may carry over.

https://github.com/iree-org/iree/commit/c66ae1957bdb2d8dd20ef3d32e4a3ab715e87869 for exact commit.

Additional context

No response

monorimet commented 5 months ago

I tried a few different configurations, and found a potentially useful runtime error when using inlined weights with SDXL:

Assertion failed: !!(iree_hal_resource_is(base_value, &iree_hal_rocm_buffer_vtable)), file C:\V\iree\experimental\rocm\rocm_buffer.c, line 25

That is what I get when running with the HIP HAL driver; using the ROCM driver works and gives the same numerics as with externalized weights.

Could this be related to how we are using --iree-stream-resource-memory-model=unified by default? I am trying with this flag set to discrete now.

benvanik commented 5 months ago

maybe you have it the other way around? the error you have says C:\V\iree\experimental\rocm\rocm_buffer.c which is ROCM, not HIP

that kind of error will happen if the driver is casting a buffer pointer instead of the iree_hal_allocated_buffer() result
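To illustrate the pattern being described, here is a toy Python model, not IREE code: every class and function name below is made up for illustration, and the real check lives in C. It assumes a driver-specific buffer can be handed to the driver inside some wrapper, and shows that a type check on the wrapper fails while resolving to the backing allocation first succeeds:

```python
class Buffer:
    """Toy stand-in for a HAL buffer (hypothetical, not IREE's API)."""

    def allocated_buffer(self):
        # A plain buffer is its own backing allocation.
        return self


class DriverBuffer(Buffer):
    """Buffer type that the toy driver knows how to read a pointer from."""

    def __init__(self, device_ptr):
        self.device_ptr = device_ptr


class WrappedBuffer(Buffer):
    """Any wrapper around another buffer, e.g. a view or pooled handle."""

    def __init__(self, backing):
        self.backing = backing

    def allocated_buffer(self):
        # Forward to whatever actually owns the allocation.
        return self.backing.allocated_buffer()


def device_pointer_direct_cast(buf):
    # Mirrors casting the incoming buffer directly: the type check
    # rejects wrappers even when the backing buffer is the right type.
    if not isinstance(buf, DriverBuffer):
        raise TypeError("buffer is not a DriverBuffer")
    return buf.device_ptr


def device_pointer(buf):
    # Resolve to the backing allocation first, then check its type.
    base = buf.allocated_buffer()
    if not isinstance(base, DriverBuffer):
        raise TypeError("backing buffer is not a DriverBuffer")
    return base.device_ptr


if __name__ == "__main__":
    wrapped = WrappedBuffer(DriverBuffer(0x1000))
    print(hex(device_pointer(wrapped)))
    try:
        device_pointer_direct_cast(wrapped)
    except TypeError as exc:
        print("direct cast failed:", exc)
```

In the toy model the failure is a clean `TypeError`; in a C driver the equivalent check is an assertion, which only fires in builds that have assertions enabled.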

monorimet commented 5 months ago

I was a bit confused by this as well, but this was for sure run with HIP driver. Will validate with cli

benvanik commented 5 months ago

if you're in a release LTO build it's possible the two functions are identical and got folded, but usually asserts and stuff prevent that - either way, good to test with a breakpoint or printf

monorimet commented 5 months ago

OK, so if I switch from my local build, configured with:

cmake -GNinja -B ../iree-build --log-level=VERBOSE -DIREE_BUILD_PYTHON_BINDINGS=ON -DIREE_BUILD_COMPILER=ON -DPython3_EXECUTABLE=C:\\V\SHARK-Turbine\turb.env\Scripts\python.exe -DCMAKE_BUILD_TYPE=Release -DIREE_HAL_DRIVER_VULKAN=ON -DIREE_HAL_DRIVER_CUDA=OFF -DIREE_EXTERNAL_HAL_DRIVERS="rocm" -DIREE_ENABLE_CPUINFO=ON -DIREE_HAL_DRIVER_ROCM=ON -DIREE_ENABLE_LLD=ON -DIREE_ENABLE_RUNTIME_TRACING=OFF -DIREE_ENABLE_ASSERTIONS=ON -DIREE_ENABLE_SPLIT_DWARF=ON

to a recent pip install of iree-runtime, instead of giving an assertion on hip hal driver, it just starts completely freezing my system for minutes at a time. Will try with resnet again to see if it completes. This seems to happen with --iree-stream-resource-memory-model=unified and --iree-stream-resource-memory-model=discrete but I've only tried this with externalized weights. Will try with inlined.

monorimet commented 5 months ago

Are the pip releases built with assertions disabled? It could explain this, if the driver is still casting the wrong pointer.

AWoloszyn commented 4 months ago

Does the workaround from here https://github.com/iree-org/iree/issues/17033 work to solve this issue as well?