saharmor / dalle-playground

A playground to generate images from any text prompt using Stable Diffusion (past: using DALL-E Mini)
MIT License

`RESOURCE_EXHAUSTED` error when starting Mega_full #78

Open davisengeler opened 2 years ago

davisengeler commented 2 years ago

Context

I’ve got the project running nicely with the mini and Mega models on an RTX 3080 Ti (12GB) under Ubuntu 22.04. Results take less than 4 seconds per image when using Mega.

Problem

Despite all other models working, I’ve not been able to start a Mega_full instance. I keep getting a RESOURCE_EXHAUSTED error after starting the server (full traceback below). It’s possible this 12GB GPU simply isn’t enough, but I figured I’d report it here for feedback before dismissing it. Thanks for any recommendations!

Full Traceback

Traceback (most recent call last):
  File "/home/davis/Desktop/dalle-playground-main/backend/app.py", line 61, in <module>
    dalle_model.generate_images("warm-up", 1)
  File "/home/davis/Desktop/dalle-playground-main/backend/dalle_model.py", line 109, in generate_images
    decoded_images = p_decode(self.vqgan, encoded_images, self.vqgan_params)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/_src/api.py", line 2026, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/_src/api.py", line 1902, in pmap_f
    out = pxla.xla_pmap(
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/core.py", line 1868, in bind
    return map_bind(self, fun, *args, **params)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/core.py", line 1900, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/core.py", line 1871, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/core.py", line 678, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 790, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 821, in parallel_callable
    pmap_executable = pmap_computation.compile()
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1090, in compile
    self._executable = PmapExecutable.from_hlo(self._hlo, **self.compile_args)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1214, in from_hlo
    compiled = dispatch.compile_or_get_cached(
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 768, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 713, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bias-activation.245 = (f32[1,128,128,256]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,128,128,256]{2,1,3,0} %gather.8, f32[3,3,256,256]{1,0,2,3} %copy.67, f32[256]{0} %get-tuple-element.111), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="pmap(p_decode)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 128, 128, 256) rhs_shape=(3, 3, 256, 256) precision=None preferred_element_type=None]" source_file="/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/flax/linen/linear.py" source_line=358}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: RESOURCE_EXHAUSTED: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/davis/Desktop/dalle-playground-main/backend/app.py", line 61, in <module>
    dalle_model.generate_images("warm-up", 1)
  File "/home/davis/Desktop/dalle-playground-main/backend/dalle_model.py", line 109, in generate_images
    decoded_images = p_decode(self.vqgan, encoded_images, self.vqgan_params)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bias-activation.245 = (f32[1,128,128,256]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,128,128,256]{2,1,3,0} %gather.8, f32[3,3,256,256]{1,0,2,3} %copy.67, f32[256]{0} %get-tuple-element.111), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="pmap(p_decode)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 128, 128, 256) rhs_shape=(3, 3, 256, 256) precision=None preferred_element_type=None]" source_file="/home/davis/Desktop/dalle-playground-main/backend/venv/lib/python3.10/site-packages/flax/linen/linear.py" source_line=358}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: RESOURCE_EXHAUSTED: Failed to allocate request for 32.00MiB (33554432B) on device ordinal 0

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.
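
The flag suggested in the error message can also be set from Python instead of the shell. This is a minimal sketch, assuming the flag is set before `jax` is imported. The helper name `add_xla_flag` is hypothetical and not part of the project. The flag only asks XLA to fall back to another convolution algorithm, so it may not help when the GPU is genuinely out of memory.

```python
import os


def add_xla_flag(flag: str, env=os.environ) -> str:
    """Append `flag` to XLA_FLAGS, keeping any flags that are already set."""
    current = env.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


# Flag name copied verbatim from the error message above.
add_xla_flag("--xla_gpu_strict_conv_algorithm_picker=false")

# Import jax only after this point, so the flag is already in the
# environment when the GPU backend is initialised.
```

Appending rather than overwriting keeps any other XLA flags that were exported earlier in the session.
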
realies commented 2 years ago

I ran into a similar error when trying to start the backend on a 1050 Ti:

dalle-backend    | You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
dalle-backend    | Traceback (most recent call last):
dalle-backend    |   File "app.py", line 60, in <module>
dalle-backend    |     dalle_model = DalleModel(args.model_version)
dalle-backend    |   File "/app/dalle_model.py", line 70, in __init__
dalle-backend    |     self.params = replicate(params)
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/flax/jax_utils.py", line 56, in replicate
dalle-backend    |     return jax.device_put_replicated(tree, devices)
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 2801, in device_put_replicated
dalle-backend    |     return tree_map(_device_put_replicated, x)
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in tree_map
dalle-backend    |     return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in <genexpr>
dalle-backend    |     return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 2796, in _device_put_replicated
dalle-backend    |     buf, = dispatch.device_put(x, devices[0])
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 871, in device_put
dalle-backend    |     return device_put_handlers[type(x)](x, device)
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 901, in _device_put_device_array
dalle-backend    |     x = _copy_device_array_to_device(x, device)
dalle-backend    |   File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 924, in _copy_device_array_to_device
dalle-backend    |     moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
dalle-backend    | jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 384.00MiB (402653184B) on device ordinal 0
dalle-backend exited with code 1
davisengeler commented 2 years ago

Are there any required/recommended GPU specs for the Mega_full model? I couldn't find anything, so I'm not sure whether 12GB of VRAM or less is enough.

Aeriit commented 2 years ago

Currently running the mega_full model on Manjaro with a 3090 - wasn't able to get it running on a GPU with less memory. The model seems to be using 12GB VRAM, so a card with "only" 12GB would indeed not be enough (especially if you consider that the model most likely isn't going to be the only thing on your computer using VRAM unless you aren't running a DE or something).
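
To see how much VRAM other processes (a desktop environment, for instance) are already using before starting the backend, something like the following may help. It is only a sketch: it assumes `nvidia-smi` is on the PATH, the helper names are made up for illustration, and the sample numbers are not measurements from this thread.

```python
import subprocess

# Ask nvidia-smi for total and used memory in MiB, one CSV line per GPU.
QUERY = [
    "nvidia-smi",
    "--query-gpu=memory.total,memory.used",
    "--format=csv,noheader,nounits",
]


def parse_memory(csv_text: str) -> list[tuple[int, int]]:
    """Parse 'total, used' pairs (MiB), one line per GPU."""
    rows = []
    for line in csv_text.strip().splitlines():
        total, used = (int(field) for field in line.split(","))
        rows.append((total, used))
    return rows


def free_mib(csv_text: str) -> list[int]:
    """Free memory per GPU in MiB."""
    return [total - used for total, used in parse_memory(csv_text)]


def query_free_mib() -> list[int]:
    """Run nvidia-smi and return free MiB per GPU (needs an NVIDIA driver)."""
    result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    return free_mib(result.stdout)


# Illustrative input only: a 12288 MiB card with 900 MiB already in use.
print(free_mib("12288, 900\n"))
```

Calling `query_free_mib()` on the machine itself gives the real figures; the parsing is kept separate so it can be checked without a GPU.
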

davisengeler commented 2 years ago

Thanks @Aeriit. I assume it won’t be possible, but are you aware of any way to work around this since I’m “so close” to the VRAM requirements?

davisengeler commented 2 years ago

Found a solution in my case

I solved my issue by setting a couple of environment variables before starting the Python backend.

  1. Run these commands in a terminal:
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
  2. Start the backend:
python3 app.py --port 8080 --model_version Mega_full

This document explains why those environment variables are helpful. I'm now able to run the backend natively using the Mega_full model on my 12GB 3080ti and generate images in ~6 seconds each.
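
The same two variables can be set from Python instead of the shell. This is a rough sketch, assuming the code runs before `jax` is imported. The function name `configure_xla_allocator` is hypothetical, and putting it at the top of `app.py` is an assumption, not something the project does.

```python
import os


def configure_xla_allocator(env=os.environ) -> None:
    """Set the two XLA client variables used in the steps above."""
    # Do not reserve a large share of GPU memory up front.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    # Allocate on demand and release memory that is no longer needed.
    # The JAX docs describe this allocator as slow, which may cost some speed.
    env["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"


configure_xla_allocator()

# Import jax only after this point; the variables are read when the
# GPU backend is first initialised.
```

Variables exported in the shell, as in step 1, have the same effect; this form just keeps the setting next to the code that depends on it.
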

cc @Aeriit @realies

Edit: I eventually had another RESOURCE_EXHAUSTED crash on Mega_full, but was able to test it for a bit. @Aeriit is correct that 12GB is right at the cusp, but not quite enough for stability.

realies commented 1 year ago

@davisengeler, unfortunately, that doesn't help on WSL for me.

nicklansley commented 1 year ago

I'm able to run Mega_full on two RTX 2080 Tis, each with 11GB of memory. The NVIDIA library spreads the model across both GPUs until they are full, and image processing is spread across both GPUs too. So if you have less than 12GB, adding another NVIDIA GPU that brings the total graphics memory above 12GB may work for you. I don't know whether the two GPUs have to be identical models.