jkulhanek / tetra-nerf

Official implementation for Tetra-NeRF paper - NeRF represented as triangulation of input point cloud.
https://jkulhanek.com/tetra-nerf
MIT License

JAX Error happened at computing mipnerf-ssim #7

Closed: ymd8bit closed this issue 1 year ago

ymd8bit commented 1 year ago

I get the error "jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details." when training on the blender/lego dataset with the pointnerf-blender ply file:

    ns-train tetra-nerf-original --pipeline.model.tetrahedra-path /data/blender/lego/pointnerf-0.5.th blender-data --data /data/nerf_synthetic/lego

Following the traceback, I found that the error occurs when JAX computes mipnerf_ssim in tetranerf/nerfstudio/model.py. Here is the relevant code fragment:

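    import dm_pix as pix
    import jax

    # JIT-compiled SSIM from dm_pix; XLA compiles it on the first call.
    jax_ssim = jax.jit(pix.ssim)

    def mipnerf_ssim(image, rgb):
        # image, rgb: torch tensors of shape (B, C, H, W). They are permuted to
        # (B, H, W, C) numpy arrays and the per-image SSIM is averaged over the batch.
        values = [
            float(jax_ssim(gt, img))  # <---- here (the first call triggers compilation)
            for gt, img in zip(image.cpu().permute(0, 2, 3, 1).numpy(), rgb.cpu().permute(0, 2, 3, 1).numpy())
        ]
        return sum(values) / len(values)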

I found some similar reports on the web, and some of them, like this one, say "it happens when the GPU reaches its memory limit". If that is the cause, please tell me how to reduce memory usage during training. However, I doubt that training on the Blender dataset, the simplest one, uses that much memory: the wandb log shows only around 4 GB in use. Please let me know if you are familiar with this behavior. My environment information and the full console log are below.
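On the memory question itself (the maintainer's reply attributes this particular error to a cuDNN mismatch, not to memory): when JAX and PyTorch share one GPU, a commonly cited source of memory pressure is that JAX preallocates a large fraction of GPU memory when it first uses the GPU. The sketch below assumes you want to turn that off from Python. The helper name configure_jax_gpu_memory is hypothetical; XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION are the environment variables described in JAX's GPU memory allocation documentation.

```python
import os
import sys
import warnings


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Hypothetical helper: set JAX's GPU allocator environment variables.

    JAX reads these when it initializes its GPU backend, so call this before
    anything in the process imports jax (e.g. at the top of the entry point).
    """
    if "jax" in sys.modules:
        warnings.warn("jax is already imported; these settings may have no effect")
    # Disable (or re-enable) up-front preallocation of GPU memory.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of total GPU memory JAX may preallocate, e.g. 0.30.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
```

Call configure_jax_gpu_memory() once, before the first import of jax; setting the variables afterwards is not guaranteed to have any effect.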

Environment

Console log

[17:53:50] Saving config to: outputs/unnamed/tetra-nerf-original/2023-05-10_175350/config.yml   experiment_config.py:129
[17:53:50] Saving checkpoints to: outputs/unnamed/tetra-nerf-original/2023-05-10_175350/nerfstudio_models trainer.py:132
Setting up training dataset...
Caching all 100 images.
Setting up evaluation dataset...
Caching all 100 images.
No checkpoints to load, training from scratch
logging events to: outputs/unnamed/tetra-nerf-original/2023-05-10_175350
Tetrahedra initialized from file data/blender/lego/pointnerf-0.5.th:
    Num points: 302781
    Num tetrahedra: 1907307
[ 4][       KNOBS]: All knobs on default.

[ 4][  DISK CACHE]: Opened database: "/var/tmp/OptixCache_user/optix7cache.db"
[ 4][  DISK CACHE]:     Cache data size: "30.3 KiB"
[ 4][   DISKCACHE]: Cache hit for key: ptx-14578-keyc4b635684e4442d9c1d4b43a60a9b30c-sm_86-rtc1-drv525.105.17
[ 4][COMPILE FEEDBACK]: 
[ 4][COMPILE FEEDBACK]: Info: Pipeline has 1 module(s), 4 entry function(s), 1 trace call(s), 0 continuation callable call(s), 0 direct callable call(s), 59 basic block(s) in entry functions, 543 instruction(s) in entry functions, 8 non-entry function(s), 63 basic block(s) in non-entry functions, 811 instruction(s) in non-entry functions, no debug information

[17:53:54] Printing max of 10 lines. Set flag --logging.local-writer.max-log-size=0 to disable line        writer.py:393
           wrapping.                                                                                                    
Step (% Done)       Train Iter (time)    ETA (time)           Train Rays / Sec      
----------------------------------------------------------------------------------- 
1900 (0.63%)        85.231 ms            7 h, 3 m, 27 s       49.29 K               
1910 (0.64%)        84.933 ms            7 h, 1 m, 57 s       49.42 K               
1920 (0.64%)        84.999 ms            7 h, 2 m, 16 s       49.38 K               
1930 (0.64%)        84.602 ms            7 h, 0 m, 17 s       49.61 K               
1940 (0.65%)        84.175 ms            6 h, 58 m, 9 s       49.94 K               
1950 (0.65%)        84.363 ms            6 h, 59 m, 4 s       49.79 K               
1960 (0.65%)        84.869 ms            7 h, 1 m, 34 s       49.40 K               
1970 (0.66%)        84.570 ms            7 h, 0 m, 4 s        49.54 K               
1980 (0.66%)        84.693 ms            7 h, 0 m, 40 s       49.50 K               
1990 (0.66%)        85.338 ms            7 h, 3 m, 51 s       49.14 K               
2023-05-10 17:57:00.496176: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Printing profiling stats, from longest to shortest duration in seconds
VanillaPipeline.get_eval_loss_dict: 0.0858              
Trainer.train_iteration: 0.0847              
VanillaPipeline.get_train_loss_dict: 0.0821              
Trainer.eval_iteration: 0.0000              
Traceback (most recent call last):
  File "/usr/local/bin/ns-train", line 8, in <module>
    sys.exit(entrypoint())
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 247, in entrypoint
    main(
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 233, in main
    launch(
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 172, in launch
    main_func(local_rank=0, world_size=world_size, config=config)
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 87, in train_loop
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/engine/trainer.py", line 267, in train
    self.eval_iteration(step)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/utils/decorators.py", line 70, in wrapper
    ret = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/utils/profiler.py", line 43, in wrapper
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/engine/trainer.py", line 447, in eval_iteration
    metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/utils/profiler.py", line 43, in wrapper
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/pipelines/base_pipeline.py", line 330, in get_eval_image_metrics_and_images
    metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
  File "/workspaces/tetra-nerf/tetranerf/nerfstudio/model.py", line 608, in get_image_metrics_and_images
    metrics_dict["mipnerf_ssim"] = float(mipnerf_ssim(image, rgb))
  File "/workspaces/tetra-nerf/tetranerf/nerfstudio/model.py", line 47, in mipnerf_ssim
    values = [
  File "/workspaces/tetra-nerf/tetranerf/nerfstudio/model.py", line 48, in <listcomp>
    float(jax_ssim(gt, img))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 208, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 155, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 2633, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 1088, in _pjit_call_impl
    always_lower=False, lowering_platform=None).compile()
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 494, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 462, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/bin/ns-train", line 8, in <module>
    sys.exit(entrypoint())
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 247, in entrypoint
    main(
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 233, in main
    launch(
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 172, in launch
    main_func(local_rank=0, world_size=world_size, config=config)
  File "/usr/local/lib/python3.10/dist-packages/scripts/train.py", line 87, in train_loop
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/engine/trainer.py", line 267, in train
    self.eval_iteration(step)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/utils/decorators.py", line 70, in wrapper
    ret = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/utils/profiler.py", line 43, in wrapper
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/engine/trainer.py", line 447, in eval_iteration
    metrics_dict, images_dict = self.pipeline.get_eval_image_metrics_and_images(step=step)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/utils/profiler.py", line 43, in wrapper
    ret = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/nerfstudio/pipelines/base_pipeline.py", line 330, in get_eval_image_metrics_and_images
    metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
  File "/workspaces/tetra-nerf/tetranerf/nerfstudio/model.py", line 608, in get_image_metrics_and_images
    metrics_dict["mipnerf_ssim"] = float(mipnerf_ssim(image, rgb))
  File "/workspaces/tetra-nerf/tetranerf/nerfstudio/model.py", line 47, in mipnerf_ssim
    values = [
  File "/workspaces/tetra-nerf/tetranerf/nerfstudio/model.py", line 48, in <listcomp>
    float(jax_ssim(gt, img))
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
wandb: Waiting for W&B process to finish... (failed 1). Press Control-C to abort syncing.
wandb: 
wandb: Run history:
wandb:               ETA (time) █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                Eval Loss ▁
wandb:  Eval Loss Dict/rgb_loss ▁
wandb:          GPU Memory (MB) ▁▄▄▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇█████████████
wandb:        Train Iter (time) █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:               Train Loss █▇▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: Train Loss Dict/rgb_loss █▇▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:         Train Rays / Sec ▁▆▇▆▆▆▆▇▆▅▅▇▅▇▆▇▇▅▇▅█▅▇▇▅▅▄▄▅▆▅▆▅▄▅▄▄▆▅▅
wandb:     learning_rate/fields ███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
wandb: 
wandb: Run summary:
wandb:               ETA (time) 25305.57381
wandb:                Eval Loss 0.00438
wandb:  Eval Loss Dict/rgb_loss 0.00438
wandb:          GPU Memory (MB) 4239.46826
wandb:        Train Iter (time) 0.08492
wandb:               Train Loss 0.00203
wandb: Train Loss Dict/rgb_loss 0.00203
wandb:         Train Rays / Sec 49336.53799
wandb:     learning_rate/fields 0.00098
wandb: 
jkulhanek commented 1 year ago

The error means that your JAX installation does not match your cuDNN library: the log shows "Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0". You can install a compatible JAX as follows:

    pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
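The cuda_dnn.cc line printed just before the traceback states the rule that was violated: the runtime cuDNN must have the same major version as the one JAX was compiled against and an equal or higher minor version. Here the runtime was 8.5.0 and the build expected 8.6.0. A small sketch that parses that message and applies the rule; the regex assumes the exact wording seen in this log, and both function names are hypothetical.

```python
import re

# Matches the XLA message from the console log (wording assumed stable).
_CUDNN_MISMATCH = re.compile(
    r"Loaded runtime CuDNN library: (\d+)\.(\d+)\.(\d+) "
    r"but source was compiled with: (\d+)\.(\d+)\.(\d+)"
)


def cudnn_compatible(runtime, compiled):
    """Rule from the log: same major version, runtime minor >= compiled minor."""
    return runtime[0] == compiled[0] and runtime[1] >= compiled[1]


def parse_cudnn_message(line):
    """Return (runtime, compiled, compatible) for a matching line, else None."""
    m = _CUDNN_MISMATCH.search(line)
    if m is None:
        return None
    nums = tuple(int(g) for g in m.groups())
    runtime, compiled = nums[:3], nums[3:]
    return runtime, compiled, cudnn_compatible(runtime, compiled)


line = ("Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  "
        "CuDNN library needs to have matching major version and equal or higher minor version.")
print(parse_cudnn_message(line))  # ((8, 5, 0), (8, 6, 0), False)
```

By the same rule, a jaxlib built for cuDNN 8.2, which is what the cuda11_cudnn82 extra name indicates, is accepted by the 8.5.0 runtime.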

Do you actually need jax? In Tetra-NeRF it was only used to provide a fair comparison with mipNeRF 360. You can safely uninstall it and the code should run well.

ymd8bit commented 1 year ago

@jkulhanek Thanks! After running pip install --upgrade "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, training seems to work properly. As you mentioned, I don't actually need JAX; I ran into this problem when running ns-train, without any changes, in the environment built by the Dockerfile, which installs JAX. I'm not sure why the JAX installed by the Dockerfile doesn't match the cuDNN version in the base container (FROM nvidia/cuda:11.7.1-devel-ubuntu22.04).
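To see which JAX build a container actually ended up with, one option is to print the installed distribution versions using Python's standard importlib.metadata. This is a hypothetical helper, not part of Tetra-NeRF or its Dockerfile, and the default package list is an assumption. CUDA-enabled jaxlib wheels from the release index above may carry a local version suffix naming the CUDA/cuDNN build they target.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "torch")):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```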

jkulhanek commented 1 year ago

Thanks! I will fix the dockerfile to install the correct jax.