Closed: ymd8bit closed this issue 1 year ago
The error means that your jax installation does not match your cuDNN library. You can install a compatible jax build as follows:
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Do you actually need jax? In Tetra-NeRF it is only used to provide a fair comparison with mip-NeRF 360. You can safely uninstall it and the code should still run fine.
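If you do keep jax, a quick sanity check along these lines (my own sketch, not part of Tetra-NeRF) forces the GPU backend and cuDNN to initialize, so a version mismatch shows up immediately instead of in the middle of training:

```python
# Minimal sanity check (a sketch, not project code): list the visible devices and
# run a small convolution, which is what triggers the cuDNN ("DNN library")
# initialization that fails when the jax wheel and cuDNN versions mismatch.
import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.devices())  # should list a GPU device, not fall back to CPU

x = jnp.ones((1, 1, 8, 8))   # NCHW input
k = jnp.ones((1, 1, 3, 3))   # OIHW kernel
y = jax.lax.conv(x, k, window_strides=(1, 1), padding="SAME")
print(y.shape)  # (1, 1, 8, 8) if cuDNN initialized correctly
```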
@jkulhanek thanks!
After running pip install --upgrade "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, my training works properly.
As you mentioned, I don't actually need jax. I encountered this problem when running ns-train without any changes to the environment built by the Dockerfile, which installs jax by default. I'm not sure why the jax installed by the Dockerfile doesn't match the cuDNN version shipped with the base container (FROM nvidia/cuda:11.7.1-devel-ubuntu22.04).
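For reference, one way to check which cuDNN the container actually provides and compare it with the installed jaxlib (a rough sketch of mine, assuming libcudnn.so is on the loader path; the cudnn82 wheel expects cuDNN 8.2.x):

```python
# Rough sketch (my assumption, not project code): query the cuDNN runtime version
# inside the container and print the installed jaxlib version for comparison.
import ctypes

cudnn = ctypes.CDLL("libcudnn.so")   # assumes libcudnn.so is discoverable
version = cudnn.cudnnGetVersion()    # e.g. 8200 for cuDNN 8.2.0, 8500 for 8.5.0
print("cuDNN runtime version:", version)

import jaxlib
print("jaxlib version:", jaxlib.__version__)
```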
Thanks! I will fix the Dockerfile to install the correct jax version.
I get the error "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more" when I train the blender/lego dataset with the pointnerf-blender ply file.
I followed the error log and found that it occurs when jax computes mipnerf_ssim in tetranerf/nerfstudio/model.py (I pasted the relevant code fragment). I found some similar issues on the web, and some of them, like this one, say it happens when the GPU reaches its memory limit, so please tell me how to reduce memory usage during training. However, I don't think training the blender dataset, the simplest one, uses that much memory; the wandb log shows it only used around 4 GB. Please let me know if you know anything about this phenomenon. I put my environment information and the full console log below.
Environment
Console log