nod-ai / SHARK

SHARK - High Performance Machine Learning Distribution
Apache License 2.0

Issue with VAE decode on Windows #2113

Open gpetters-amd opened 2 months ago

gpetters-amd commented 2 months ago

I'm getting runtime errors with VAE decode on Vulkan. It works fine on Linux, but on Windows I get the following traceback:

Traceback (most recent call last):
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\queueing.py", line 495, in call_prediction
    output = await route_utils.call_process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\route_utils.py", line 235, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\blocks.py", line 1627, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\blocks.py", line 1185, in call_function
    prediction = await utils.async_iteration(iterator)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\utils.py", line 514, in async_iteration
    return await iterator.__anext__()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\utils.py", line 507, in __anext__
    return await anyio.to_thread.run_sync(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\anyio\to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\anyio\_backends\_asyncio.py", line 2144, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\anyio\_backends\_asyncio.py", line 851, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\utils.py", line 490, in run_sync_iterator_async
    return next(iterator)
           ^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\gradio\utils.py", line 673, in gen_wrapper
    response = next(iterator)
               ^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\api\sd.py", line 441, in shark_sd_fn_dict_input
    generated_imgs = yield from shark_sd_fn(**sd_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\api\sd.py", line 568, in shark_sd_fn
    out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\api\sd.py", line 416, in generate_images
    imgs = self.decode_latents(
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\api\sd.py", line 337, in decode_latents
    images = self.run("vae_decode", latents_numpy).to_host()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\modules\pipeline.py", line 194, in run
    return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\iree\runtime\function.py", line 127, in __call__
    self._invoke(arg_list, ret_list)
  File "C:\Users\gpetters\Desktop\SHARK\shark.venv\Lib\site-packages\iree\runtime\function.py", line 147, in _invoke
    self._vm_context.invoke(self._vm_function, arg_list, ret_list)
RuntimeError: Error invoking function: D:\a\iree\iree\c\runtime\src\iree\hal\drivers\vulkan\direct_command_queue.cc:114: UNKNOWN; VkResult=4294967283; while invoking native function hal.device.queue.dealloca; while calling import;
[ 2]   native hal.device.queue.dealloca:0 -
[ 1] bytecode compiled_vae.main$async:32826 C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\vae_decode.torch.tempfile:250:3
[ 0] bytecode compiled_vae.main:62 C:\Users\gpetters\Desktop\SHARK\apps\shark_studio\web\shark_tmp\vae_decode.torch.tempfile:250:3

I'm sure it's a runtime issue rather than a compilation issue, since VMFBs that work on Linux give the same error when copied over to Windows. There have been similar issues in the past, but those failed in hal.device.queue.create and were due to VRAM limitations, whereas this one fails in hal.device.queue.dealloca.
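
For reference, the VkResult in the status line is printed as an unsigned 32-bit value, while Vulkan error codes are negative signed integers. Below is a minimal sketch for reading such a value; the `decode_vkresult` helper and the lookup table are only illustrative (they are not part of SHARK or IREE), and the table covers just a few codes from the Vulkan specification.

```python
# Illustrative helper (not part of SHARK or IREE): reinterpret the unsigned
# VkResult printed in the IREE status message as a signed 32-bit value.

# Small subset of VkResult codes from the Vulkan specification.
VK_RESULT_NAMES = {
    0: "VK_SUCCESS",
    -1: "VK_ERROR_OUT_OF_HOST_MEMORY",
    -2: "VK_ERROR_OUT_OF_DEVICE_MEMORY",
    -3: "VK_ERROR_INITIALIZATION_FAILED",
    -4: "VK_ERROR_DEVICE_LOST",
    -13: "VK_ERROR_UNKNOWN",
}


def decode_vkresult(raw: int) -> str:
    """Convert an unsigned 32-bit VkResult into 'NAME (signed value)'."""
    # Values at or above 2**31 are negative when read as a signed int32.
    signed = raw - (1 << 32) if raw >= (1 << 31) else raw
    name = VK_RESULT_NAMES.get(signed, "unlisted VkResult")
    return f"{name} ({signed})"


print(decode_vkresult(4294967283))
```

For the value in the traceback this prints `VK_ERROR_UNKNOWN (-13)`, which matches the `UNKNOWN` status reported by IREE. So the code by itself does not point at out-of-memory (-2) or device-lost (-4); the driver only reports an unspecified failure.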