google / hypernerf

Code for "HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields".
https://hypernerf.github.io
Apache License 2.0

Model initialization step within colab using tpu and default configuration exits with error. #24

Open owenschris opened 2 years ago

owenschris commented 2 years ago

The model initialization step in Colab, using a TPU and the default configuration, exits with an error.

The errors are nested through jax and hypernerf, but it appears that the root is https://github.com/google/hypernerf/blob/d433ebeba4ddd91fd83aa9af3423333d2d5934e7/hypernerf/model_utils.py#L119 within the volumetric_rendering function, at jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape).

The relevant error is

/usr/local/lib/python3.7/dist-packages/hypernerf/model_utils.py in volumetric_rendering(rgb, sigma, z_vals, dirs, use_white_background, sample_at_infinity, eps)
    113 z_vals[..., 1:] - z_vals[..., :-1],
--> 114 jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    115 ], -1)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/util.py in _broadcast_to(arr, shape)
    341 return arr.broadcast_to(shape)
--> 342 _check_arraylike("broadcast_to", arr)
    343 arr = arr if isinstance(arr, ndarray) else _asarray(arr)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/util.py in _check_arraylike(fun_name, *args)
    294 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 295 raise TypeError(msg.format(fun_name, type(arg), pos))
    296

UnfilteredStackTrace: TypeError: broadcast_to requires ndarray or scalar arguments, got <class 'list'> at position 0.

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.

A quick search brought up things like https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax, which suggests that all elements should be converted to jnp arrays. I haven't gotten it working yet, though.

saunair commented 2 years ago

I got it to work by changing the line to: jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape)
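
For anyone who wants to check the fix outside the repository, here is a minimal sketch of the same pattern. The z_vals shape and the value of last_sample_z are made-up stand-ins for illustration, not the values HyperNeRF actually uses, and whether the list form raises depends on the installed JAX version.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the values inside volumetric_rendering.
z_vals = jnp.linspace(0.0, 1.0, 8).reshape(2, 4)  # (rays, samples per ray)
last_sample_z = 1e10  # assumed placeholder for the width of the final interval

try:
    # Original form: a plain Python list is passed to broadcast_to.
    jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    list_accepted = True
except TypeError:
    # JAX versions that require array or scalar inputs end up here.
    list_accepted = False

# Fixed form: wrap the list in a jnp array before broadcasting.
last_dist = jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape)

# Distances between adjacent samples, with the padded final interval appended.
dists = jnp.concatenate([z_vals[..., 1:] - z_vals[..., :-1], last_dist], -1)

print("list input accepted:", list_accepted)
print("dists shape:", dists.shape)
```

The only change that matters is the jnp.array(...) wrapper, which turns the list into an array before broadcast_to checks its argument.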

corlangerak commented 2 years ago

@saunair

Could you maybe share your version of the Colab notebook? I am still experiencing issues after changing the line to your suggestion... Thanks a lot for the help! :)

saunair commented 2 years ago

Hey @corlangerak, did you remove the pip install hypernerf line in the Colab notebook? If not, changing the line locally won't take effect, because you'll be loading the pip-installed hypernerf instead. I suggest:

  1. Copy the notebook out of the notebooks folder into the project's root folder (i.e. hypernerf/my_notebook.ipynb instead of hypernerf/notebooks/my_notebook.ipynb).
  2. Apply the fix I suggested, and then run the notebook in Colab.
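
One way to confirm which copy of hypernerf a notebook will load is to ask Python where the module would be imported from. The sketch below uses only the standard library; module_origin is a hypothetical helper name, not something from the hypernerf repository.

```python
import importlib.util


def module_origin(name):
    """Return the file a top-level module would be imported from, or None if it is not found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


# A path under dist-packages or site-packages suggests the pip-installed copy;
# a path inside your checkout suggests the locally edited copy.
print(module_origin("hypernerf"))
```

If this prints a path under dist-packages, the local edit to model_utils.py is most likely not the code being run.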

hsauod commented 2 years ago

@saunair Hi Nair, I faced the same issue, and it persists even after I changed the line as you suggested above. Could you give me slightly more specific instructions? Thank you in advance, Nair.

saunair commented 2 years ago

@hsauod check these two changes: https://github.com/google/hypernerf/commit/ae29d1dc5824daaa59a7008df96442873017346e#diff-433be35a4beb7eeee9224dcbe28ec97d53330cd175060905cd5217863674003cR114

and check the second cell in my notebook here: https://github.com/saunair/hypernerf/blob/main/notebooks/HyperNeRF_Training.ipynb

@corlangerak here you go. Sorry about the delay

hsauod commented 2 years ago

@saunair Hi Nair, thank you very much for your kind explanation.

silvercondor commented 2 years ago

Hi, in case anyone else runs into this error, a quick fix for Colab is to install @saunair's version of hypernerf.

In cell 1, replace !pip install git+https://github.com/google/hypernerf

with !pip install git+https://github.com/saunair/hypernerf

Thanks again to @saunair for fixing the error.