google / hypernerf

Code for "HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields".
https://hypernerf.github.io
Apache License 2.0
896 stars · 105 forks

Issue doing the test training run - "flax.errors.InvalidRngError" #2

Open kurtjcu opened 3 years ago

kurtjcu commented 3 years ago

Hi, Thanks for the code. :) I have been unable to run the example code due to the following error.

python train.py \
    --base_folder $EXPERIMENT_PATH \
    --gin_bindings="data_dir='$DATASET_PATH'" \
    --gin_configs configs/test_local.gin
2021-10-16 22:37:32.827267: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2021-10-16 22:37:32.827288: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1835] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
I1016 22:37:32.828192 140583580991872 train.py:135] *** Starting experiment
I1016 22:37:32.828281 140583580991872 train.py:139] *** Loading Gin configs from: ['configs/test_local.gin']
I1016 22:37:32.840747 140583580991872 xla_bridge.py:231] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I1016 22:37:32.949953 140583580991872 xla_bridge.py:231] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I1016 22:37:32.950088 140583580991872 train.py:159] exp_dir = /disk2/nerf-review/hypernerf/test01
I1016 22:37:32.950333 140583580991872 train.py:163] summary_dir = /disk2/nerf-review/hypernerf/test01/summaries/train
I1016 22:37:32.950404 140583580991872 train.py:167] checkpoint_dir = /disk2/nerf-review/hypernerf/test01/checkpoints
I1016 22:37:32.950486 140583580991872 train.py:171] Starting process 0. There are 1 processes.
I1016 22:37:32.950564 140583580991872 train.py:173] Found 2 accelerator devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)].
I1016 22:37:32.950626 140583580991872 train.py:175] Found 2 total devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)].
I1016 22:37:33.311031 140583580991872 train.py:187] Creating datasource
I1016 22:37:33.311344 140583580991872 nerfies.py:72] *** Loading dataset IDs from /disk2/nerf-review/datasets/hypernerf/achen_000000_nerfies_04/capture_upright_01/dataset.json
I1016 22:37:33.312014 140583580991872 core.py:237] Creating datasource of type NerfiesDataSource with use_appearance_id=True, use_camera_id=False, use_warp_id=True, use_depth=False, use_time=False, train_stride=1, val_stride=1
I1016 22:37:33.312376 140583580991872 train.py:200] Initializing models.
Traceback (most recent call last):
  File "train.py", line 370, in <module>
    app.run(main)
  File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "train.py", line 203, in main
    model, params['model'] = models.construct_nerf(
  File "/disk2/nerf-review/hypernerf/hypernerf/models.py", line 701, in construct_nerf
    params = model.init({
  File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 998, in init
    _, v_out = self.init_with_output(
  File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 968, in init_with_output
    return self.apply(
  File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/linen/module.py", line 936, in apply
    return apply(
  File "/home/kurt/anaconda3/envs/hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 686, in wrapper
    with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
  File "/home/kurt/anaconda3/envs//hypernerf/lib/python3.8/site-packages/flax/core/scope.py", line 663, in bind
    raise errors.InvalidRngError(
flax.errors.InvalidRngError: rngs should be a dictionary mapping strings to `jax.PRNGKey`. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.InvalidRngError)

I followed the installation instructions, except that I changed "jaxlib==0.1.71+cuda111" to "jaxlib==0.1.71" to allow installation (followed by installing jax from GitHub). The dataset has previously been trained successfully with Nerfies.

taoranyi commented 3 years ago

"pip install jax==0.2.20" may solve your problem

SuwoongHeo commented 2 years ago

This error seems related to a version mismatch between jax and flax.

ChikaYan commented 2 years ago

"pip install jax==0.2.20" may solve your problem

Thx! Installing jax==0.2.20 successfully resolves the flax.errors.InvalidRngError.

Sadly, it now causes a segmentation fault raised by line 330 of train.py. I'm running on a single GPU with 47 GB of memory, and I've tried setting batch_size = 64, NerfModel.num_coarse_samples = 16, and NerfModel.num_fine_samples = 16 in configs/test_local.gin, but the segfault still happens, so I suppose it shouldn't be an issue of insufficient memory?
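For reference, the overrides described above would look like this in configs/test_local.gin (a sketch of the commenter's debugging settings, not a recommended configuration):

```gin
# Reduced settings tried while debugging the segfault.
batch_size = 64
NerfModel.num_coarse_samples = 16
NerfModel.num_fine_samples = 16
```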

nahyup commented 2 years ago

I think the jaxlib version may cause the segfault. I suffered the same issue.

"pip install jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html" may solve your problem

ChikaYan commented 2 years ago

> I think that jaxlib version may cause the segfault. I suffered same issue.
>
> "pip install jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html" may solve your problem

Thx a lot! The issue has been resolved.
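Putting the two fixes from this thread together, a sketch of the version pair reported to work (assuming CUDA 11.1; pick the jaxlib wheel matching your CUDA version from the jax-releases index):

```shell
# Hedged sketch: jax 0.2.20 resolves flax.errors.InvalidRngError, and a
# CUDA-enabled jaxlib wheel avoids the segfault seen with the CPU-only
# jaxlib==0.1.71 build.
pip install jax==0.2.20
pip install "jaxlib==0.1.71+cuda111" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```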

Zvyozdo4ka commented 9 months ago

@ChikaYan @SuwoongHeo @taoranyi after training, did you get points.npy? Could you give any hints on how to obtain it?