Traceback (most recent call last):
  File "train.py", line 574, in <module>
    app.run(main)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "train.py", line 495, in main
    process_iterator(tag='runtime_eval',
  File "train.py", line 528, in process_iterator
    model_out = render_fn(state, batch, rng=rng)
  File "/home/wayve/prajwal/d2nerf/hypernerf/evaluation.py", line 119, in render_image
    chunk_rays_dict = utils.shard(chunk_rays_dict, device_count)
  File "/home/wayve/prajwal/d2nerf/hypernerf/utils.py", line 289, in shard
    return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/wayve/prajwal/d2nerf/hypernerf/utils.py", line 289, in <lambda>
    return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1756, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1750, in _compute_newshape
    return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1750, in <genexpr>
    return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/core.py", line 1438, in divide_shape_sizes
    return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/core.py", line 1348, in divide_shape_sizes
    raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (60, 3) and (8, -1, 3)
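I encounter this error when training on 8 GPUs (specifically, with vrig_balloon). It is raised while rendering an image during evaluation: utils.shard reshapes every array in a chunk of rays to (device_count, -1, ...), and a chunk of 60 rays cannot be split evenly across 8 devices because 60 is not a multiple of 8.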
Here's the corresponding PR that fixes the bug:
https://github.com/d2nerf/d2nerf/pull/8
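For context, one common way to avoid this kind of error is to pad each chunk along its leading axis up to a multiple of the device count before sharding, then drop the padded rows from the rendered output. The sketch below illustrates that idea with NumPy (jax.numpy offers the same pad and reshape calls); it is not necessarily what the linked PR does. The helper names pad_to_multiple and unpad are hypothetical, the ray dict keys are illustrative, and shard only mirrors the lambda shown in the traceback.

```python
import numpy as np


def shard(xs, device_count):
    """Mirror of the reshape in hypernerf/utils.py from the traceback:
    split the leading axis of every array into (device_count, -1)."""
    return {k: x.reshape((device_count, -1) + x.shape[1:]) for k, x in xs.items()}


def pad_to_multiple(xs, device_count):
    """Hypothetical helper: pad the leading axis of every array so its length
    is a multiple of device_count, repeating the last row ('edge' mode).

    Returns the padded dict and the number of rows that were added.
    """
    n = next(iter(xs.values())).shape[0]
    padding = (-n) % device_count  # e.g. 60 rays, 8 devices -> 4 extra rows
    if padding == 0:
        return xs, 0
    padded = {
        k: np.pad(x, [(0, padding)] + [(0, 0)] * (x.ndim - 1), mode="edge")
        for k, x in xs.items()
    }
    return padded, padding


def unpad(x, padding):
    """Hypothetical helper: drop the padded rows after outputs are flattened.
    The explicit check avoids x[:-0], which would return an empty array."""
    return x[: x.shape[0] - padding] if padding else x


# Illustrative chunk matching the failing case: 60 rays, 3 values per ray.
device_count = 8
chunk = {"origins": np.zeros((60, 3)), "directions": np.ones((60, 3))}

padded, padding = pad_to_multiple(chunk, device_count)
sharded = shard(padded, device_count)  # now (8, 8, 3) instead of failing
restored = unpad(sharded["origins"].reshape(-1, 3), padding)
print(padding, sharded["origins"].shape, restored.shape)
```

Padding with repeated rows keeps every device busy on valid-looking rays; the extra results are simply discarded afterwards, so the rendered image is unaffected.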