Alescontrela / viper_rl

Using advances in generative modeling to learn reward functions from unlabeled videos.
MIT License
103 stars, 11 forks

FAILED_PRECONDITION: DNN library initialization failed. #8

Open Benjamin-So opened 1 month ago

Benjamin-So commented 1 month ago

Description: I'm encountering an error when attempting to run the sample Policy RL script. It seems to be related to utilizing CUDA on a GPU. The error message indicates a failure to initialize the DNN library.

UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Environment:

OS: Ubuntu 22.04
CUDA Version: 11.8
JAX Version: 0.4.13
JAXlib Version: 0.4.13+cuda11.cudnn86
Python Version: 3.8.19

Steps to set up

Attempts to narrow issue

Stack Trace ...

NonsansWD commented 1 month ago

Hey, I have exactly the same problem and have been unable to solve it for days, so if you find a solution, please post it here. I don't know if you have tried this, but I found that JAX tries to allocate a lot of VRAM when it runs its first operation. You can try limiting the allocation, either with "export XLA_PYTHON_CLIENT_PREALLOCATE=false" to stop it from preallocating at all, or with "export XLA_PYTHON_CLIENT_MEM_FRACTION=.XX", where XX is a percentage (for example, .50 gives it 50% of the VRAM). In my case this did not solve the issue, but maybe you will have more luck.

Besides that, I'm now trying to get things running on Python 3.9, since JAX works in my Python 3.9 environment, but this is painful and doesn't seem to be easy. I'm afraid nobody will solve this issue here, so I suggest looking through the issues on the JAX GitHub. If you find a solution there, please tell me; I haven't found one yet, and I posted an issue there myself that nobody has answered. Sorry, that's all I can give you.
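For anyone who would rather set these two variables from Python than from the shell, here is a minimal sketch. The helper name `xla_memory_settings` is hypothetical; only the two environment variable names come from the comment above. It assumes the variables are set before the first `import jax` in the process, since they are read when the GPU backend starts up.

```python
import os
from typing import Dict, Optional


def xla_memory_settings(mem_fraction: Optional[float] = None) -> Dict[str, str]:
    """Build the environment settings for one of the two options above.

    Hypothetical helper: with no argument it disables preallocation,
    otherwise it caps the share of VRAM that JAX may claim.
    """
    if mem_fraction is None:
        return {"XLA_PYTHON_CLIENT_PREALLOCATE": "false"}
    if not 0.0 < mem_fraction <= 1.0:
        raise ValueError("mem_fraction must be in (0, 1]")
    return {"XLA_PYTHON_CLIENT_MEM_FRACTION": f"{mem_fraction:.2f}"}


if __name__ == "__main__":
    # Apply the settings to this process, then import JAX afterwards.
    os.environ.update(xla_memory_settings(mem_fraction=0.5))
    print(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])
    # import jax  # only after the variables are in place
```

As reported in this thread, limiting memory did not fix the cuDNN initialization error, so treat this as something to rule out and not as the fix.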

Alescontrela commented 1 month ago

I'm at ICML this week, so I haven't had much time to look at issues. This is not a Torch library; installing both JAX and Torch may be causing some strange dependency issues. Additionally, the error:

DNN library initialization failed. Look at the errors above for more details.

This looks like a cuDNN error to me: it doesn't seem that JAX can find cuDNN on your system. The instructions for installing CUDA in the README are outdated. Try the recommended install: pip install -U "jax[cuda12]" or pip install -U "jax[cuda11]", depending on your version of CUDA.
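A rough way to check whether a cuDNN shared library is visible to the system loader at all is sketched below. It uses only the standard library (`ctypes.util.find_library`); the function name `find_cudnn` is hypothetical. This is an assumption-laden check: JAX's own library lookup may differ from what `find_library` does, so a hit or a miss here is a hint, not proof.

```python
import ctypes.util
from typing import Optional


def find_cudnn() -> Optional[str]:
    """Return the cuDNN library name the system loader reports, or None.

    Hypothetical helper: this only asks the platform's generic library
    search, which is not necessarily the search JAX performs.
    """
    return ctypes.util.find_library("cudnn")


if __name__ == "__main__":
    found = find_cudnn()
    if found is None:
        print("cuDNN is not visible to the system loader")
    else:
        print("loader found:", found)
```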

NonsansWD commented 1 month ago

Hey, I can already tell you that this is an issue with JAX, which runs into problems during cuDNN initialization. Interestingly, JAX works completely fine for me with any Python version higher than 3.8; it just doesn't work with Python 3.8 for some reason. Unfortunately, the issue I posted about this in the JAX repo is being ignored, and there seems to be no real solution yet, unless I just did not find it.

I think the installation was a little different on Python 3.8, because I ran into the problem that the extra "cuda12" does not exist there. With the commands you recommended, you get: "WARNING: jax 0.4.13 does not provide the extra 'cuda12'", so that is not an option: the extra does not exist for the latest JAX version usable with Python 3.8. The installation I had to go through looked something like this: pip install jax==0.4.13 jaxlib==0.4.13+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, which installs jaxlib with CUDA 12 and cuDNN 8.9 support, but this ran into the issue mentioned by OP.

I would be really thankful for any further help, as I cannot get it to run and it is really important for my work. Maybe you could create a new conda environment yourself to test it, if that is possible. I'm currently even trying to make the code run with Python 3.9 because JAX works there, but that comes with a lot of issues I am trying to fix, so it's not looking too good.

NonsansWD commented 1 month ago

Small update: using Python 3.9 is not an option. It allows newer JAX versions, but those versions have removed functions that viper relies on, so I guess it could only be fixed with major changes to the viper source code.

Benjamin-So commented 1 month ago

Thank you for your comments @Alescontrela @NonsansWD. I managed to solve the problem on my end, but I'm not sure if the solution will apply to you. It's worth a try. I set up the environment as I mentioned above. If you run pip show jaxlib, you should get output specifying "Version: 0.4.13+cuda11.cudnn86". I was installing cuDNN incorrectly: I had to replace /usr/local/cuda/ with a different path (I needed to install on a supercloud and within a conda environment). If you're installing this locally, the following should resolve the issue.

If you want to install it within a conda environment, you can identify your cuda path using which nvcc. The output will be something like your_cuda_path/some_conda_environment/bin/nvcc. So you would replace /usr/local/cuda/ with your_cuda_path/some_conda_environment.

Then export the path to the conda environment to your LD_LIBRARY_PATH. I hope that helps!
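The path substitution described above can be sketched in Python as follows. The function names are hypothetical, and the `lib` subdirectory is an assumption about where the conda environment keeps its shared libraries; adjust it if your cuDNN files live elsewhere. The script prints an `export` line instead of changing the environment itself, on the assumption that LD_LIBRARY_PATH needs to be set in the shell before Python starts.

```python
import os
import shutil
from pathlib import Path


def cuda_root_from_nvcc(nvcc_path: str) -> Path:
    """Strip the trailing bin/nvcc from the path that `which nvcc` reports."""
    return Path(nvcc_path).parent.parent


def ld_library_path_export(cuda_root: Path, current: str = "") -> str:
    """Build a shell line that prepends the CUDA root's lib dir.

    Assumption: libraries sit in <cuda_root>/lib.
    """
    parts = [str(cuda_root / "lib")]
    if current:
        parts.append(current)
    return "export LD_LIBRARY_PATH=" + ":".join(parts)


if __name__ == "__main__":
    nvcc = shutil.which("nvcc")  # same lookup as `which nvcc`
    if nvcc is None:
        print("nvcc not found on PATH")
    else:
        root = cuda_root_from_nvcc(nvcc)
        print(ld_library_path_export(root, os.environ.get("LD_LIBRARY_PATH", "")))
```

The printed line can be pasted into the shell or added to a bashrc before launching the training script.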

NonsansWD commented 1 month ago

You won't believe it, but before looking at this comment I managed to solve it too, in exactly the same way xD. For me it was a little different because I built the matching cuDNN version on Arch, but it worked. It's also fine to use cuDNN 8.9 with CUDA 12, which is what I'm using now; on another system I'm using cuDNN 8.6 with CUDA 11, and it works there too. So the only problem is basically the major version update to cuDNN 9.x, which JAX cannot support. I got that hint from a reply posted today on my issue on the JAX GitHub.

As for the paths, I already have symlinks in place and my bashrc prepared, so it works with any environment out of the box. Thank you very much for sharing this; I hope anyone who stumbles upon this issue will see our comments and that one of these approaches works for them too. Have a nice time running the code. I hope it won't take too long, because it wants to do about a million steps and starts at 40,000 :D
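The version mismatch described here can be checked mechanically. The sketch below parses a jaxlib version string of the form seen in this thread (for example "0.4.13+cuda11.cudnn86") and compares its cuDNN major version with an installed cuDNN version string that you supply yourself. All names are hypothetical, and the parser assumes a single-digit cuDNN major version in the tag, which matches the two tags quoted in this thread but is not a documented format guarantee.

```python
import re
from typing import NamedTuple


class JaxlibBuild(NamedTuple):
    jaxlib: str
    cuda_major: int
    cudnn_major: int
    cudnn_minor: int


# Assumption: tags look like "<version>+cuda<NN>.cudnn<M><m>", e.g. cudnn86 = 8.6.
_TAG = re.compile(r"^(?P<jaxlib>[\d.]+)\+cuda(?P<cuda>\d+)\.cudnn(?P<major>\d)(?P<minor>\d+)$")


def parse_jaxlib_version(version: str) -> JaxlibBuild:
    """Split a jaxlib version string into its CUDA and cuDNN parts."""
    match = _TAG.match(version.strip())
    if match is None:
        raise ValueError(f"unrecognized jaxlib version tag: {version!r}")
    return JaxlibBuild(match["jaxlib"], int(match["cuda"]),
                       int(match["major"]), int(match["minor"]))


def cudnn_major_matches(jaxlib_version: str, installed_cudnn: str) -> bool:
    """True if the installed cuDNN major version equals the one jaxlib was built for."""
    build = parse_jaxlib_version(jaxlib_version)
    return int(installed_cudnn.split(".")[0]) == build.cudnn_major


if __name__ == "__main__":
    print(parse_jaxlib_version("0.4.13+cuda11.cudnn86"))
    print(cudnn_major_matches("0.4.13+cuda12.cudnn89", "9.1.0"))
```

This only compares major versions, which is the mismatch reported in this thread; it does not check minor-version or CUDA compatibility.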

NonsansWD commented 1 month ago

@Benjamin-So Quick question, though: since I fixed it, it has been running constantly, and after 5 days I got to about 380k. Do you experience something similar, or is it significantly faster for you? It's still way faster than running on CPU, but it takes a long time. Also, are your loss curves and images as bad as mine, as described in my recently posted issue?