ankitatiisc opened this issue 2 years ago (status: open)
Modifying: https://github.com/google/hypernerf/blob/main/hypernerf/model_utils.py#L114
From:
jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
To:
import numpy as np
...
jnp.broadcast_to(np.asarray([last_sample_z]), z_vals[..., :1].shape)
This fixed the problem for me (with versions pinned to jax==0.2.20 and jaxlib==0.1.71).
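For context, here is a minimal sketch of what the patched line does, assuming it sits in a volume-rendering step that computes the distance between adjacent samples along each ray and pads the final interval with a large constant. The helper name `sample_intervals` and the default `last_sample_z=1e10` are illustrative assumptions, not necessarily what HyperNeRF uses. The point of the patch is that the scalar is wrapped in an array before it reaches `jnp.broadcast_to`, since newer `jax.numpy` functions appear to reject plain Python lists as array arguments.

```python
import jax.numpy as jnp
import numpy as np


def sample_intervals(z_vals, last_sample_z=1e10):
    """Distances between adjacent samples along each ray (hypothetical helper).

    z_vals: array of shape (..., num_samples) holding sorted sample depths.
    The last sample has no successor, so its interval is padded with
    last_sample_z (an assumed default, chosen to be effectively infinite).
    """
    # Wrap the scalar in an array instead of passing the bare list
    # [last_sample_z] to jnp.broadcast_to; this mirrors the patch above.
    last = jnp.broadcast_to(
        np.asarray([last_sample_z], dtype=z_vals.dtype), z_vals[..., :1].shape
    )
    # Differences between consecutive depths: shape (..., num_samples - 1).
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    # Append the padded last interval so the output matches z_vals' shape.
    return jnp.concatenate([deltas, last], axis=-1)


z = jnp.array([[0.0, 1.0, 3.0], [2.0, 2.5, 4.0]])
print(sample_intervals(z))
```

Using `jnp.asarray([last_sample_z])` instead of `np.asarray` should work equally well and avoids the extra NumPy import; `jnp.full_like(z_vals[..., :1], last_sample_z)` is another equivalent way to build the padding column.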
Hi authors, I am getting this error while running the train.py script. I have not made any changes to the code. May I ask what your build setup and configuration were? Also, is this code tied to a specific JAX version, or is it forward-compatible with newer versions of JAX?