Abhishek-Varma opened this issue 11 months ago
(I'm not sure how to provide input for the scalar argument here; I tried providing a dummy scalar value, 1xf16, 0xf16, but none worked.)
--input=f16=0.0
works for me.
@antiagainst @MaheshRavishankar I confirm that the issue persists with the latest SRT and I get 500+ seconds (reported via iree-benchmark-module).
This doesn't occur with the 20231006.543 SRT version, where we get 1.5-2+ seconds and it generates correct output.
Compile (--iree-flow-break-dispatch=forward_dispatch_156_):
iree-compile --iree-llvmcpu-target-cpu-features=host --iree-hal-cuda-llvm-target-arch=sm_89 --iree-stream-resource-max-allocation-size=4294967295 --iree-vm-bytecode-module-strip-source-map=true --iree-util-zero-fill-elided-attrs --iree-opt-strip-assertions=true --iree-hal-target-backends=cuda --iree-preprocessing-pass-pipeline="builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))" .\unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1.mlir -o .\unet_cuda.vmfb --iree-flow-break-dispatch=forward_dispatch_156_
Run:
iree-benchmark-module --device_allocator=caching --module=unet_cuda.vmfb --device=cuda --function=forward --input=2x4x128x128xf16 --input=1xf16 --input=2x77x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=f16=2.0
Summary:
Breaking at dispatch 100 => 0.161 ms
Breaking at dispatch 156 => 5629 ms
Breaking at dispatch 242 => 16618 ms
Breaking at dispatch 607 => 94552 ms
...
Full model => 543691 ms
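For completeness, here is a rough sketch of how the numbers above can be collected by sweeping the break point, assuming the same flags as the compile and run commands above; the abbreviated flag list and output file names are placeholders, not the exact invocations used.

import subprocess

# Dispatch break points taken from the summary above.
DISPATCHES = [100, 156, 242, 607]

# Placeholder: carry over the full flag list from the iree-compile command above.
COMPILE_FLAGS = [
    "--iree-hal-target-backends=cuda",
    "--iree-hal-cuda-llvm-target-arch=sm_89",
]

for d in DISPATCHES:
    vmfb = f"unet_cuda_break_{d}.vmfb"
    # Compile a module whose execution is cut off after the given dispatch.
    subprocess.run(
        ["iree-compile", *COMPILE_FLAGS,
         f"--iree-flow-break-dispatch=forward_dispatch_{d}_",
         "unet_1_77_1024_1024_fp16_stable-diffusion-xl-base-1.mlir",
         "-o", vmfb],
        check=True,
    )
    # Benchmark the truncated module with the same inputs as the full run.
    subprocess.run(
        ["iree-benchmark-module", "--device=cuda", "--device_allocator=caching",
         f"--module={vmfb}", "--function=forward",
         "--input=2x4x128x128xf16", "--input=1xf16", "--input=2x77x2048xf16",
         "--input=2x1280xf16", "--input=2x6xf16", "--input=f16=2.0"],
        check=True,
    )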
Nice. Could you upload the source for forward_dispatch_156_?
Sure - here is forward_dispatch156.
What happened?
For the UNet model, execution is extremely slow (~500/200+ seconds on various backends).
This ticket provides context for the RTX 4090.
Steps to reproduce your issue
(I'm not sure how to provide input for the scalar argument here; I tried providing a dummy scalar value, 1xf16, 0xf16, but none worked.) So I made the following script as a repro, observable via SHARK after compiling:
from shark.shark_inference import SharkInference
import torch

# get_vmfb_from_path comes from SHARK's utilities (its import was omitted in the original script).
shark_model = get_vmfb_from_path("unet_cuda.vmfb", "cuda", mlir_dialect="tm_tensor")

latents = torch.ones(2, 4, 128, 128).to(torch.float16)
timesteps = torch.ones(1).to(torch.float16)
prompt_embeds = torch.ones(2, 77, 2048).to(torch.float16)
text_embeds = torch.ones(2, 1280).to(torch.float16)
time_ids = torch.ones(2, 6).to(torch.float16)
guidance_scale = torch.tensor(1).to(torch.float16)

inputs = (latents, timesteps, prompt_embeds, text_embeds, time_ids, guidance_scale)
output_shark = shark_model("forward", inputs)
print(output_shark)
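For a SHARK-free repro, the same dummy inputs can also be written out as .npy files and passed to iree-benchmark-module using the @file.npy input form (assuming the installed build accepts numpy-file inputs); the file names below are arbitrary:

import numpy as np

# Same shapes and values as the SHARK script above, saved for --input=@<file>.npy.
np.save("latents.npy", np.ones((2, 4, 128, 128), dtype=np.float16))
np.save("timesteps.npy", np.ones((1,), dtype=np.float16))
np.save("prompt_embeds.npy", np.ones((2, 77, 2048), dtype=np.float16))
np.save("text_embeds.npy", np.ones((2, 1280), dtype=np.float16))
np.save("time_ids.npy", np.ones((2, 6), dtype=np.float16))
np.save("guidance_scale.npy", np.float16(1.0))  # 0-d f16 scalar

Each --input=<shape>xf16 flag in the iree-benchmark-module command above would then be replaced by the corresponding --input=@<file>.npy.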
iree-compiler 20231013.550
iree-runtime 20231013.550