Open dspringst opened 2 years ago
Thank you. What version of Ubuntu are you on? I was having issues trying these steps tonight, even with successful installations along the way. After looking at the cuDNN Support Matrix, I’m wondering if it might be because I’m on Ubuntu 22.04.
Edit: never mind, I got it working and Ubuntu 22.04 was not the problem.
I can confirm that the installation works on WSL and that it doesn't default back to CPU. However, it crashes like this:
2022-06-16 15:38:36.720344: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc:57] cuLinkAddData fails. This is usually caused by stale driver version.
2022-06-16 15:38:36.720473: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1248] The CUDA linking API did not work. Please use XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to bypass it, but expect to get longer compilation time due to the lack of multi-threading.
Traceback (most recent call last):
File "backend/app.py", line 60, in <module>
dalle_model = DalleModel(args.model_version)
File "/home/calops/dalle-playground/backend/dalle_model.py", line 62, in __init__
dalle_model, revision=DALLE_COMMIT_ID, dtype=dtype, _do_init=False
File "/home/calops/.local/lib/python3.7/site-packages/dalle_mini/model/utils.py", line 26, in from_pretrained
pretrained_model_name_or_path, *model_args, **kwargs
File "/home/calops/.local/lib/python3.7/site-packages/transformers/modeling_flax_utils.py", line 596, in from_pretrained
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
File "/home/calops/.local/lib/python3.7/site-packages/transformers/models/bart/modeling_flax_bart.py", line 920, in __init__
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
File "/home/calops/.local/lib/python3.7/site-packages/transformers/modeling_flax_utils.py", line 115, in __init__
self.key = PRNGKey(seed)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/random.py", line 125, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/prng.py", line 236, in seed_with_impl
return PRNGKeyArray(impl, impl.seed(seed))
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/prng.py", line 276, in threefry_seed
lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 444, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/calops/.local/lib/python3.7/site-packages/jax/core.py", line 323, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/calops/.local/lib/python3.7/site-packages/jax/core.py", line 326, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/calops/.local/lib/python3.7/site-packages/jax/core.py", line 675, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 99, in apply_primitive
**params)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/util.py", line 219, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/util.py", line 212, in cached
return f(*args, **kwargs)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 149, in xla_primitive_callable
prim.name, donated_invars, False, *arg_specs)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 231, in _xla_callable_uncached
keep_unused, *arg_specs).compile().unsafe_call
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 705, in compile
self.name, self._hlo, self._explicit_args, **self.compile_args)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 806, in from_xla_computation
compiled = compile_or_get_cached(backend, xla_computation, options)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 768, in compile_or_get_cached
return backend_compile(backend, computation, compile_options)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/calops/.local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 713, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_asm_compiler.cc(60): 'status'
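The log above suggests the `XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1` workaround. The flag has to be in the environment before jax initializes its backend; a minimal sketch of doing that from Python:

```python
import os

# Workaround suggested in the log: force single-threaded compilation so XLA
# bypasses the CUDA linking API (expect longer compile times).
# Set this BEFORE importing jax so the backend picks it up.
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"
```

The same effect comes from setting it on the command line, e.g. `XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 python backend/app.py`.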
Had the same issue as @calops, fixed it by following this: https://github.com/google/jax/issues/5723#issuecomment-1132802697
Basically, make sure your CUDA driver version (shown by nvidia-smi) and your CUDA toolkit version (shown by nvcc --version) are the same. A mismatch can happen if you're on a technically unsupported OS (like me on Fedora 36, where the driver comes from dnf and the toolkit from NVIDIA's website).
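As a sketch of that check, the two versions can be pulled out of the tools' output with regexes. The exact output formats here are assumptions based on typical nvidia-smi and nvcc output, not guaranteed:

```python
import re

def parse_driver_cuda(smi_output: str):
    """Extract the CUDA version the driver reports in `nvidia-smi` header output."""
    m = re.search(r"CUDA Version:\s*([\d.]+)", smi_output)
    return m.group(1) if m else None

def parse_toolkit_cuda(nvcc_output: str):
    """Extract the toolkit release version from `nvcc --version` output."""
    m = re.search(r"release\s+([\d.]+)", nvcc_output)
    return m.group(1) if m else None
```

Feed these the captured stdout of the two commands (e.g. via `subprocess.run(..., capture_output=True, text=True)`) and compare the results.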
I think this is unrelated to this issue though.
@u1f98e’s comment seems relevant to other Linux distributions too. When I was having problems, I noticed I had two different versions installed.
In general, I found the official cuDNN installation docs helpful, including the prerequisites. When installing the CUDA Toolkit according to the docs, don’t skip the mandatory post-installation actions section. After that, follow the instructions in the jax repo to install the particular CUDA-enabled version you need.
TL;DR: these steps worked for me to get the GPU working:
After this, installing the requirements.txt with pip and running the backend worked for me. If it still fails, check @u1f98e’s comment above. If you have multiple versions installed, remove the old one and try running the backend again.
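If it's unclear which jax/jaxlib builds pip actually installed, a quick way to check from Python (a sketch; the `+cuda` version-suffix convention is my assumption based on jaxlib's CUDA wheels, and `importlib.metadata` needs Python 3.8+):

```python
from importlib import metadata

def installed_versions(pkgs=("jax", "jaxlib")):
    """Report installed versions; a CUDA-enabled jaxlib typically shows a
    version string containing '+cuda'. Missing packages map to None."""
    out = {}
    for pkg in pkgs:
        try:
            out[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            out[pkg] = None
    return out

print(installed_versions())
```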
Hello,
The steps for setting up a local development environment seem to be missing some pieces, leaving people with misconfigured environments when they want to use their GPU.
The main culprit, and the source of most of the trouble, is that the current setup step installs the non-CUDA version of jax on the user's machine.
To install the CUDA version of jax, the user has to follow the installation instructions on NVIDIA's website to install the latest CUDA Toolkit, sign up as an NVIDIA developer to get access to the cuDNN library and install that, and then follow the instructions on the jax GitHub page to get the CUDA build of jax. This works on Ubuntu; I'm fairly sure it's doable on WSL as well, but not 100% certain.
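Once everything above is installed, a minimal check that jax actually picked up the GPU rather than silently falling back to CPU (a sketch using only jax's public API; the exact device names printed vary by version):

```python
def gpu_status():
    """Describe jax's default backend, or report why jax failed to import."""
    try:
        import jax
    except Exception as e:  # jax missing, or CUDA libs misconfigured
        return f"jax unavailable: {e}"
    # default_backend() is "gpu" when the CUDA build is installed and working
    return f"default backend: {jax.default_backend()}, devices: {jax.devices()}"

print(gpu_status())
```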
Hopefully this will clear up a lot of the issues people are having with setup.