
SDXL Unet ROCM runtime issues #15603

Abhishek-Varma opened 7 months ago

Abhishek-Varma commented 7 months ago

What happened?

There are two issues to resolve here for the input IR: unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1.mlir

Compiling on ROCm Linux for gfx90a (note the --iree-flow-break-dispatch=forward_dispatch_323_ flag):

iree-compile \
  --iree-llvmcpu-target-cpu-features=host \
  --iree-rocm-link-bc=true \
  --iree-stream-resource-max-allocation-size=4294967295 \
  --iree-vm-bytecode-module-strip-source-map=true \
  --iree-util-zero-fill-elided-attrs \
  --iree-opt-strip-assertions=true \
  --verify=false \
  --iree-rocm-target-chip=gfx90a \
  --iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-pad-linalg-ops{pad-size=32}))" \
  unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1.mlir \
  -o unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1_rocm.vmfb \
  --iree-hal-target-backends=rocm \
  --iree-flow-break-dispatch=forward_dispatch_323_

Run using the following:

iree-run-module \
  --device_allocator=caching \
  --module=unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1_rocm.vmfb \
  --device=rocm \
  --function=forward \
  --input=2x4x128x128xf16 \
  --input=1xf16 \
  --input=2x77x2048xf16 \
  --input=2x1280xf16 \
  --input=2x6xf16 \
  --input=2.0

Section A: using iree-run-module causes the following issues:

  1. For dispatches < 323:
    EXEC @forward
    c/runtime/src/iree/vm/list.c:825: FAILED_PRECONDITION; invoking function 'forward'
  2. For dispatches >= 323:
    c/experimental/rocm/status_util.c:31: INTERNAL; rocm driver error 'hipErrorOutOfMemory' (2): out of memory; failed to allocate buffer of length 840293696; while invoking native function hal.allocator.allocate; while calling import; 
    [ 1]   native hal.allocator.allocate:0 -
    [ 0] bytecode module@1:21574 -; creating VM context; creating run context

Section B: invoking the forward function from a Python script causes the following issues:

  1. For dispatches < 323: NO ISSUE.
  2. For dispatches >= 323 (same as Section A):
    c/experimental/rocm/status_util.c:31: INTERNAL; rocm driver error 'hipErrorOutOfMemory' (2): out of memory; failed to allocate buffer of length 840293696; while invoking native function hal.allocator.allocate; while calling import; 
    [ 1]   native hal.allocator.allocate:0 -
    [ 0] bytecode module@1:21574 -; creating VM context; creating run context

So, effectively, there are two issues that need to be resolved, both captured entirely by Section A itself.

Steps to reproduce your issue

For the Section A issues above, use the compilation and run commands given earlier.

For the Section B issue above, the following script is required:

import torch

# SHARK utility for loading a precompiled vmfb module.
from apps.language_models.utils import (
    get_vmfb_from_path,
)

# Load the compiled module on the ROCm device.
shark_model = get_vmfb_from_path(
    "unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1_rocm.vmfb",
    "rocm",
    mlir_dialect="tm_tensor",
)

# Construct dummy inputs matching the shapes/dtypes the module expects.
latents = torch.ones(2, 4, 128, 128).to(torch.float16)
timesteps = torch.ones(1).to(torch.float16)
prompt_embeds = torch.ones(2, 77, 2048).to(torch.float16)
text_embeds = torch.ones(2, 1280).to(torch.float16)
time_ids = torch.ones(2, 6).to(torch.float16)
guidance_scale = torch.tensor(1).to(torch.float16)
inputs = (latents, timesteps, prompt_embeds, text_embeds, time_ids, guidance_scale)

# Invoke the module's forward function.
output_shark = shark_model("forward", inputs)
print(output_shark)
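
Note that the tensor shapes and dtypes above mirror the --input flags passed to iree-run-module earlier; guidance_scale is supplied here as a 0-d f16 tensor rather than a bare scalar.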

In case the elided IR is needed: unet_elided.mlir

I'm also attaching dispatch 323 here: dispatch_323.mlir

What component(s) does this issue relate to?

No response

Version information

I'm using the following SRT version:

iree-compiler             20231106.574
iree-runtime              20231106.574

Additional context

No response

Groverkss commented 7 months ago

Issue A is because you are passing the input incorrectly: instead of --input=2.0 it should be --input=f16=2.0.
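
For reference, the full corrected invocation (everything else unchanged from the run command above) would be:

iree-run-module \
  --device_allocator=caching \
  --module=unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1_rocm.vmfb \
  --device=rocm \
  --function=forward \
  --input=2x4x128x128xf16 \
  --input=1xf16 \
  --input=2x77x2048xf16 \
  --input=2x1280xf16 \
  --input=2x6xf16 \
  --input=f16=2.0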

vivekkhandelwal1 commented 7 months ago

Hi @nithinsubbiah, any updates on this issue? I'm also getting the same error.

Found ROCm device arch : gfx90a
Saved vmfb in /app/phaneesh/SHARK/falcon_180b_layer_0_20_int4_rocm.vmfb.
Saved falcon vmfb at  /app/phaneesh/SHARK/falcon_180b_layer_0_20_int4_rocm.vmfb
Loading module /app/phaneesh/SHARK/falcon_180b_layer_0_20_int4_rocm.vmfb...
Traceback (most recent call last):
  File "/app/phaneesh/SHARK/apps/language_models/src/pipelines/falcon_pipeline.py", line 1097, in <module>
    falcon = ShardedFalcon(
             ^^^^^^^^^^^^^^
  File "/app/phaneesh/SHARK/apps/language_models/src/pipelines/falcon_pipeline.py", line 145, in __init__
    self.shark_model = self.compile()
                       ^^^^^^^^^^^^^^
  File "/app/phaneesh/SHARK/apps/language_models/src/pipelines/falcon_pipeline.py", line 386, in compile
    shark_module, device_idx = self.compile_layer(
                               ^^^^^^^^^^^^^^^^^^^
  File "/app/phaneesh/SHARK/apps/language_models/src/pipelines/falcon_pipeline.py", line 295, in compile_layer
    shark_module.load_module(path)
  File "/app/phaneesh/SHARK/shark/shark_inference.py", line 232, in load_module
    params = load_flatbuffer(
             ^^^^^^^^^^^^^^^^
  File "/app/phaneesh/SHARK/shark/iree_utils/compile_utils.py", line 519, in load_flatbuffer
    vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/app/phaneesh/SHARK/shark/iree_utils/compile_utils.py", line 450, in load_vmfb_using_mmap
    ctx.add_vm_module(mmaped_vmfb)
  File "/app/phaneesh/SHARK/shark.venv/lib/python3.11/site-packages/iree/runtime/system_api.py", line 271, in add_vm_module
    self.add_vm_modules((vm_module,))
  File "/app/phaneesh/SHARK/shark.venv/lib/python3.11/site-packages/iree/runtime/system_api.py", line 268, in add_vm_modules
    self._vm_context.register_modules(vm_modules)
RuntimeError: Error registering modules: c/experimental/rocm/status_util.c:31: INTERNAL; rocm driver error 'hipErrorOutOfMemory' (2): out of memory; failed to allocate buffer of length 4025827328; while invoking native function hal.allocator.allocate; while calling import;
[ 1]   native hal.allocator.allocate:0 -
[ 0] bytecode module@1:4294 -
nithinsubbiah commented 7 months ago

Hi Vivek, I'm still working on this. The error seems to be because the device is actually running out of memory, and not a ROCm HAL driver error, which would mean we need to shard or quantize the model. I'll investigate further and let you know.

Could you please share the IR that fails?

vivekkhandelwal1 commented 7 months ago

> Hi Vivek, I'm still working on this. The error seems to be because the device is actually running out of memory, and not a ROCm HAL driver error, which would mean we need to shard or quantize the model. I'll investigate further and let you know.
>
> Could you please share the IR that fails?

Yeah, you're correct. The error was actually because of OOM. Some other processes were running that I didn't know of. Anyway, thanks!
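
As an aside, a quick way to confirm this kind of contention (assuming a standard ROCm install; flag spellings may vary across ROCm versions) is to check VRAM usage and the processes holding it before launching:

rocm-smi --showmeminfo vram
rocm-smi --showpids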

nithinsubbiah commented 7 months ago

@Abhishek-Varma I am able to run SDXL on 6010 (gfx90a) successfully. I looked at the trace and dispatches and couldn't find anything offending. It's likely that multiple processes were running at the same time, causing the OOM error. Please let me know if it works for you.