iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

RTX 4090 UNET SDXL runtime execution is SLOW #15174

Open Abhishek-Varma opened 11 months ago

Abhishek-Varma commented 11 months ago

### What happened?

For the UNET model, execution is extremely slow (~500 s / 200+ s depending on the backend).

This ticket provides context for the RTX 4090.

### Steps to reproduce your issue

  1. Download UNET model
  2. Compile command:
    iree-compile --iree-llvmcpu-target-cpu-features=host --iree-hal-cuda-llvm-target-arch=sm_89 --iree-stream-resource-max-allocation-size=4294967295 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-opt-strip-assertions=true --iree-hal-target-backends=cuda --iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))" .\unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1.mlir -o .\unet_cuda.vmfb
  3. Runtime command:
    iree-run-module --device_allocator=caching --module=unet_cuda.vmfb --device=cuda --function=forward --input=2x4x128x128xf16 --input=1xf16 --input=2x77x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=<A SCALAR>

    (I'm not sure how to provide the input for the SCALAR here; I tried providing a dummy scalar value, 1xf16, 0xf16, but none worked.) So I made the following script as a repro, observable via SHARK after compiling:

    
    from apps.language_models.utils import (
        get_vmfb_from_path,
    )
    import torch

    from shark.shark_inference import SharkInference

    shark_model = get_vmfb_from_path("unet_cuda.vmfb", "cuda", mlir_dialect="tm_tensor")

    latents = torch.ones(2, 4, 128, 128).to(torch.float16)
    timesteps = torch.ones(1).to(torch.float16)
    prompt_embeds = torch.ones(2, 77, 2048).to(torch.float16)
    text_embeds = torch.ones(2, 1280).to(torch.float16)
    time_ids = torch.ones(2, 6).to(torch.float16)
    guidance_scale = torch.tensor(1).to(torch.float16)

    inputs = (latents, timesteps, prompt_embeds, text_embeds, time_ids, guidance_scale)
    output_shark = shark_model("forward", inputs)
    print(output_shark)
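A rough way to observe the regression directly from this script (a minimal sketch, assuming the setup above; the timing wrapper is not part of the original repro) is to wall-clock the forward call:

    import time

    start = time.perf_counter()
    # Reuses shark_model and inputs from the script above.
    output_shark = shark_model("forward", inputs)
    print(f"forward() took {time.perf_counter() - start:.2f} s")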


### What component(s) does this issue relate to?

_No response_

### Version information

SRT version:

    iree-compiler 20231013.550
    iree-runtime  20231013.550



### Additional context

Doesn't happen for `20231006.543`.

Here is the elided version too - [UNET elided](https://storage.googleapis.com/shark_tank/SDXL/mlir/unet_elided.mlir).
antiagainst commented 11 months ago

> (I'm not sure how to provide the input for the SCALAR here; I tried providing a dummy scalar value, 1xf16, 0xf16, but none worked.)

`--input=f16=0.0` works for me.
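
For reference, the run command from the reproduction steps then becomes (only the final scalar input filled in):

    iree-run-module --device_allocator=caching --module=unet_cuda.vmfb --device=cuda --function=forward --input=2x4x128x128xf16 --input=1xf16 --input=2x77x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=f16=0.0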

Abhishek-Varma commented 11 months ago

@antiagainst @MaheshRavishankar I confirm that the issue persists with the latest SRT: I get 500+ seconds (reported via iree-benchmark-module). This doesn't occur with SRT version 20231006.543, where we get 1.5-2+ seconds and correct output.

Compile (with `--iree-flow-break-dispatch=forward_dispatch_156_`):

iree-compile --iree-llvmcpu-target-cpu-features=host --iree-hal-cuda-llvm-target-arch=sm_89 --iree-stream-resource-max-allocation-size=4294967295 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-opt-strip-assertions=true --iree-hal-target-backends=cuda --iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))" .\unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1.mlir -o .\unet_cuda.vmfb --iree-flow-break-dispatch=forward_dispatch_156_

Run:

iree-benchmark-module --device_allocator=caching --module=unet_cuda.vmfb --device=cuda --function=forward --input=2x4x128x128xf16 --input=1xf16 --input=2x77x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=f16=2.0

Summary:

Breaking at dispatch 100 => 0.161 ms
Breaking at dispatch 156 => 5629 ms
Breaking at dispatch 242 => 16618 ms
Breaking at dispatch 607 => 94552 ms
...
Full model => 543691 ms
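
For completeness, here is a minimal sketch of how such a dispatch bisection can be scripted (assuming the compile and benchmark commands above; the `forward_dispatch_<N>_` naming pattern and the output file names are assumptions, not taken verbatim from this thread):

    # Hedged sketch: rebuild the model broken at successive dispatches and
    # benchmark each build, mirroring the iree-compile / iree-benchmark-module
    # invocations quoted above.
    import subprocess

    MLIR = "unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1.mlir"
    PIPELINE = (
        "builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,"
        "iree-flow-convert-1x1-filter-conv2d-to-matmul,"
        "iree-preprocessing-convert-conv2d-to-img2col,"
        "iree-preprocessing-pad-linalg-ops{pad-size=32}))"
    )

    for dispatch in ["forward_dispatch_100_", "forward_dispatch_156_",
                     "forward_dispatch_242_", "forward_dispatch_607_"]:
        vmfb = f"unet_cuda_{dispatch}.vmfb"  # placeholder output name
        subprocess.run([
            "iree-compile",
            "--iree-llvmcpu-target-cpu-features=host",
            "--iree-hal-cuda-llvm-target-arch=sm_89",
            "--iree-stream-resource-max-allocation-size=4294967295",
            "--iree-vm-bytecode-module-strip-source-map=true",
            "--iree-util-zero-fill-elided-attrs",
            "--iree-opt-strip-assertions=true",
            "--iree-hal-target-backends=cuda",
            f"--iree-preprocessing-pass-pipeline={PIPELINE}",
            f"--iree-flow-break-dispatch={dispatch}",
            MLIR, "-o", vmfb,
        ], check=True)
        subprocess.run([
            "iree-benchmark-module",
            "--device_allocator=caching",
            f"--module={vmfb}",
            "--device=cuda",
            "--function=forward",
            "--input=2x4x128x128xf16", "--input=1xf16", "--input=2x77x2048xf16",
            "--input=2x1280xf16", "--input=2x6xf16", "--input=f16=2.0",
        ], check=True)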
antiagainst commented 11 months ago

Nice. Could you upload the source for forward_dispatch_156_?

Abhishek-Varma commented 11 months ago

Sure - here is forward_dispatch156.