saharmor / dalle-playground

A playground to generate images from any text prompt using Stable Diffusion (past: using DALL-E Mini)
MIT License

Local Development guide leads to misconfigured environment when using a GPU #52

Open dspringst opened 2 years ago

dspringst commented 2 years ago

Hello,

The steps for setting up a local development environment seem to be incomplete, and they leave people who want to use their GPU with misconfigured environments.

The main culprit seems to be that the current installation step installs the non-CUDA (CPU-only) version of jax.

To install the CUDA version of jax, the user has to install the latest CUDA Toolkit by following the instructions on NVIDIA's website, sign up as an NVIDIA developer to get access to the cuDNN library and install it, and then follow the instructions on the jax GitHub page to install the CUDA build of jax. This works on Ubuntu, and I'm fairly sure it is doable on WSL, but I'm not 100% certain.
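
After going through these steps, one quick way to tell whether the CPU-only or the CUDA build of jax ended up installed is to ask jax which backend it picked. This is a minimal sketch, assuming a jax release that provides `jax.default_backend()` and `jax.devices()`; the helper name `describe_jax_backend` is hypothetical, not part of dalle-playground.

```python
def describe_jax_backend():
    """Return (backend, devices) for the installed jax, or None if jax is missing.

    A CPU-only jaxlib reports "cpu" even on a machine with an NVIDIA GPU;
    a working CUDA build typically reports "gpu".
    """
    try:
        import jax
    except ImportError:
        return None
    # jax.devices() lists the devices on the default backend.
    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    result = describe_jax_backend()
    if result is None:
        print("jax is not installed")
    else:
        backend, devices = result
        print(f"backend={backend} devices={devices}")
        if backend == "cpu":
            print("jax is running on CPU only; the CUDA build is probably missing")
```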

Hopefully this clears up a lot of the setup issues people are having.

davisengeler commented 2 years ago

Thank you. What version of Ubuntu are you on? I was having issues trying these steps tonight, even though each installation along the way succeeded. After looking at the cuDNN Support Matrix, I’m wondering if it might be because I’m on Ubuntu 22.04.

Edit: never mind, I got it working and Ubuntu 22.04 was not the problem.

calops commented 2 years ago

I can confirm that the installation works on WSL and that it doesn't fall back to the CPU. However, it crashes like this:

2022-06-16 15:38:36.720344: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc:57] cuLinkAddData fails. This is usually caused by stale driver version.
2022-06-16 15:38:36.720473: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1248] The CUDA linking API did not work. Please use XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to bypass it, but expect to get longer compilation time due to the lack of multi-threading.
Traceback (most recent call last):
  File "backend/app.py", line 60, in <module>
    dalle_model = DalleModel(args.model_version)
  File "/home/calops/dalle-playground/backend/dalle_model.py", line 62, in __init__
    dalle_model, revision=DALLE_COMMIT_ID, dtype=dtype, _do_init=False
  File "/home/calops/.local/lib/python3.7/site-packages/dalle_mini/model/utils.py", line 26, in from_pretrained
    pretrained_model_name_or_path, *model_args, **kwargs
  File "/home/calops/.local/lib/python3.7/site-packages/transformers/modeling_flax_utils.py", line 596, in from_pretrained
    model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
  File "/home/calops/.local/lib/python3.7/site-packages/transformers/models/bart/modeling_flax_bart.py", line 920, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  File "/home/calops/.local/lib/python3.7/site-packages/transformers/modeling_flax_utils.py", line 115, in __init__
    self.key = PRNGKey(seed)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/random.py", line 125, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/prng.py", line 236, in seed_with_impl
    return PRNGKeyArray(impl, impl.seed(seed))
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/prng.py", line 276, in threefry_seed
    lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 444, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/core.py", line 675, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 99, in apply_primitive
    **params)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 149, in xla_primitive_callable
    prim.name, donated_invars, False, *arg_specs)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 231, in _xla_callable_uncached
    keep_unused, *arg_specs).compile().unsafe_call
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 705, in compile
    self.name, self._hlo, self._explicit_args, **self.compile_args)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 806, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 768, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 713, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc(60): 'status'
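
The first lines of the log suggest setting `XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1` to bypass the failing CUDA linking API, at the cost of slower, single-threaded compilation. A minimal sketch of adding that flag without clobbering any existing `XLA_FLAGS`, assuming it is set before jax initializes its backend (the helper `with_xla_flag` is hypothetical). It only works around the linker failure; it may not fix the `no kernel image is available` error, which the replies below trace to a driver/toolkit mismatch.

```python
import os

FLAG = "--xla_gpu_force_compilation_parallelism=1"


def with_xla_flag(env, flag=FLAG):
    """Return a copy of env with flag appended to XLA_FLAGS, without duplicates.

    XLA reads XLA_FLAGS at startup, so apply this before jax is imported
    (or export the variable in the shell that launches backend/app.py).
    """
    new_env = dict(env)
    existing = new_env.get("XLA_FLAGS", "").split()
    if flag not in existing:
        existing.append(flag)
    new_env["XLA_FLAGS"] = " ".join(existing)
    return new_env


if __name__ == "__main__":
    os.environ["XLA_FLAGS"] = with_xla_flag(os.environ)["XLA_FLAGS"]
    # Import jax only after XLA_FLAGS is set.
    print(os.environ["XLA_FLAGS"])
```
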

u1f98e commented 2 years ago

Had the same issue as @calops, fixed it by following this: https://github.com/google/jax/issues/5723#issuecomment-1132802697

Basically, make sure your CUDA driver version (shown by nvidia-smi) and CUDA toolkit version (shown by nvcc --version) match. A mismatch can happen if you're using a technically unsupported OS (like me on Fedora 36, where I'm using the driver from dnf and the toolkit from NVIDIA's website). I think this is unrelated to this issue, though.
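
The "CUDA Version" that nvidia-smi prints is the highest CUDA version the installed driver supports, so in practice the requirement is usually that it be at least the toolkit version; a driver older than the toolkit also fits the "stale driver version" hint in the log above. This is a minimal sketch of that comparison. It assumes the nvidia-smi header contains `CUDA Version: X.Y` and that `nvcc --version` prints `release X.Y`; all function names are hypothetical.

```python
import re
import shutil
import subprocess


def _first_version(pattern, text):
    """Return (major, minor) from the first regex match in text, or None."""
    m = re.search(pattern, text)
    return (int(m.group(1)), int(m.group(2))) if m else None


def driver_cuda_version(smi_output):
    """Highest CUDA version the driver supports, from the nvidia-smi header."""
    return _first_version(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)


def toolkit_cuda_version(nvcc_output):
    """CUDA toolkit version, from the 'release X.Y' part of `nvcc --version`."""
    return _first_version(r"release (\d+)\.(\d+)", nvcc_output)


def driver_supports_toolkit(driver, toolkit):
    """True if both versions are known and the driver's is at least the toolkit's."""
    return driver is not None and toolkit is not None and driver >= toolkit


def run(cmd):
    """Return the stdout of cmd, or "" if the executable is not on PATH."""
    if shutil.which(cmd[0]) is None:
        return ""
    return subprocess.run(cmd, capture_output=True, text=True).stdout


if __name__ == "__main__":
    driver = driver_cuda_version(run(["nvidia-smi"]))
    toolkit = toolkit_cuda_version(run(["nvcc", "--version"]))
    print(f"driver supports CUDA {driver}, toolkit is CUDA {toolkit}")
    if not driver_supports_toolkit(driver, toolkit):
        print("mismatch: update the driver or install a toolkit it supports")
```

Comparing `(major, minor)` tuples keeps the check simple: Python orders them element by element, so `(11, 4) >= (11, 7)` is `False`.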

davisengeler commented 2 years ago

I believe @u1f98e’s comment applies to other Linux distributions too. When I was having problems, I noticed that two different CUDA versions were installed.

In general, I found the official cuDNN installation docs helpful, including the prerequisites. When installing the CUDA Toolkit according to the docs, don’t forget the mandatory post-installation actions section. After that, use the instructions from the jax repo to install the particular version you need.
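
The post-installation actions in NVIDIA's Linux guide mostly come down to putting the toolkit's `bin` directory on `PATH` and, for runfile installs, its `lib64` directory on `LD_LIBRARY_PATH`. This is a minimal sketch that checks an environment for those two entries. It assumes the toolkit lives under `/usr/local/cuda` (pass the versioned directory, for example `/usr/local/cuda-11.7`, if that is what you added); `cuda_env_issues` is a hypothetical helper. A missing `LD_LIBRARY_PATH` entry is only a hint, since package-manager installs may register the libraries another way.

```python
import os


def cuda_env_issues(env, cuda_home="/usr/local/cuda"):
    """List missing post-install PATH/LD_LIBRARY_PATH entries for cuda_home.

    cuda_home defaults to /usr/local/cuda, which is an assumption about
    where the toolkit was installed.
    """
    bin_dir = os.path.join(cuda_home, "bin")
    lib_dir = os.path.join(cuda_home, "lib64")
    # Both variables are colon-separated on Linux; ignore trailing slashes.
    path_entries = [p.rstrip("/") for p in env.get("PATH", "").split(":")]
    lib_entries = [p.rstrip("/") for p in env.get("LD_LIBRARY_PATH", "").split(":")]
    issues = []
    if bin_dir not in path_entries:
        issues.append(f"{bin_dir} is not on PATH")
    if lib_dir not in lib_entries:
        issues.append(f"{lib_dir} is not on LD_LIBRARY_PATH")
    return issues


if __name__ == "__main__":
    for issue in cuda_env_issues(os.environ):
        print(issue)
```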

TL;DR: these steps got the GPU working for me:

  1. Install NVIDIA graphics drivers
  2. Install CUDA Toolkit (don’t forget post-install steps)
  3. Install zlib (prereq for cuDNN)
  4. Download cuDNN and install it
  5. Install jax for CUDA
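
The five steps above can be sketched as an ordered preflight check, so the first step that reports MISSING is the one to revisit. These are proxies, not proof: it assumes `nvidia-smi` and `nvcc` on `PATH` indicate steps 1 and 2, and that `ctypes.util.find_library` can locate `libz` and `libcudnn` (it may miss libraries the loader would still find). The function names are hypothetical.

```python
import ctypes.util
import shutil


def _jax_on_gpu():
    """True if jax imports and picked the GPU backend."""
    try:
        import jax
    except ImportError:
        return False
    return jax.default_backend() == "gpu"


def preflight():
    """Return (step, ok) pairs in install order; later steps depend on earlier ones."""
    return [
        ("1. NVIDIA driver (nvidia-smi on PATH)", shutil.which("nvidia-smi") is not None),
        ("2. CUDA Toolkit (nvcc on PATH)", shutil.which("nvcc") is not None),
        ("3. zlib (libz found)", ctypes.util.find_library("z") is not None),
        ("4. cuDNN (libcudnn found)", ctypes.util.find_library("cudnn") is not None),
        ("5. jax for CUDA (GPU backend)", _jax_on_gpu()),
    ]


if __name__ == "__main__":
    for step, ok in preflight():
        print(f"[{'ok' if ok else 'MISSING'}] {step}")
```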

After this, installing requirements.txt with pip and running the backend worked for me. If it still fails, check @u1f98e’s comment above. If you have multiple CUDA versions installed, remove the old one and try running the backend again.