google / hypernerf

Code for "HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields".
https://hypernerf.github.io
Apache License 2.0

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3958, 3) and (8, -1, 3) #9

Open SangbumChoi opened 2 years ago

SangbumChoi commented 2 years ago

Hi, while running the demo in a Jupyter notebook, a shape error occurs when the code reaches the rendering step.

Any clue on how to fix this error?

9 frames
/usr/local/lib/python3.7/dist-packages/hypernerf/evaluation.py in render_image(state, rays_dict, model_fn, device_count, rng, chunk, default_ret_key)
    114         lambda x: x[(proc_id * per_proc_rays):((proc_id + 1) * per_proc_rays)],
    115         chunk_rays_dict)
--> 116     chunk_rays_dict = utils.shard(chunk_rays_dict, device_count)
    117     model_out = model_fn(key_0, key_1, state.optimizer.target['model'],
    118                          chunk_rays_dict, state.extra_params)

/usr/local/lib/python3.7/dist-packages/hypernerf/utils.py in shard(xs, device_count)
    287   if device_count is None:
    288     jax.local_device_count()
--> 289   return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
    290 
    291 

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    176   leaves, treedef = tree_flatten(tree, is_leaf)
    177   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 178   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    179 
    180 tree_multimap = tree_map

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
    176   leaves, treedef = tree_flatten(tree, is_leaf)
    177   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 178   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    179 
    180 tree_multimap = tree_map

/usr/local/lib/python3.7/dist-packages/hypernerf/utils.py in <lambda>(x)
    287   if device_count is None:
    288     jax.local_device_count()
--> 289   return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
    290 
    291 

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _reshape(a, order, *args)
   1727 
   1728 def _reshape(a, *args, order="C"):
-> 1729   newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
   1730   if order == "C":
   1731     return lax.reshape(a, newshape, None)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1723   return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
   1724                if core.symbolic_equal_dim(d, -1) else d
-> 1725                for d in newshape)
   1726 
   1727 

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in <genexpr>(.0)
   1723   return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
   1724                if core.symbolic_equal_dim(d, -1) else d
-> 1725                for d in newshape)
   1726 
   1727 

/usr/local/lib/python3.7/dist-packages/jax/core.py in divide_shape_sizes(s1, s2)
   1407   s2 = s2 or (1,)
   1408   handler, ds = _dim_handler_and_canonical(*s1, *s2)
-> 1409   return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
   1410 
   1411 def same_shape_sizes(s1: Shape, s2: Shape) -> bool:

/usr/local/lib/python3.7/dist-packages/jax/core.py in divide_shape_sizes(self, s1, s2)
   1322       return 1
   1323     if sz1 % sz2:
-> 1324       raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
   1325     return sz1 // sz2
   1326 

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3958, 3) and (8, -1, 3)

Spark001 commented 2 years ago

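If I guess right, you ran the code on 8 GPUs, so `device_count` is 8. The reshape fails because the chunk has 3958 rays, and 3958 is not evenly divisible by 8 (3958 % 8 = 6). An annoying workaround is to change the device count from 8 to 2, since 3958 is evenly divisible by 2. Hopefully someone can offer a better solution :(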
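
Another option, which avoids reducing the device count, might be to pad the ray batch up to a multiple of `device_count` before sharding and drop the padded entries afterwards. The following is only a minimal sketch under assumptions, not code from the HyperNeRF repository: the helper names `pad_and_shard` and `unshard_and_unpad` are hypothetical, the rays are assumed to be a flat dict of arrays that share the same leading axis, and plain NumPy is used so the snippet runs without any accelerator.

```python
import numpy as np


def pad_and_shard(rays_dict, device_count):
    """Pad the leading (ray) axis to a multiple of device_count, then shard.

    Hypothetical helper: returns the sharded dict and the number of padded rays.
    """
    num_rays = next(iter(rays_dict.values())).shape[0]
    # Number of extra rays needed to reach the next multiple of device_count.
    padding = (-num_rays) % device_count
    sharded = {}
    for key, x in rays_dict.items():
        if padding:
            # Repeat the last ray so the padded entries are still valid inputs.
            pad_block = np.repeat(x[-1:], padding, axis=0)
            x = np.concatenate([x, pad_block], axis=0)
        # Same reshape as in the traceback, now guaranteed to divide evenly.
        sharded[key] = x.reshape((device_count, -1) + x.shape[1:])
    return sharded, padding


def unshard_and_unpad(x, padding):
    """Merge the device axis back into the ray axis and drop padded rays."""
    x = x.reshape((-1,) + x.shape[2:])
    return x[:x.shape[0] - padding] if padding else x


# Same sizes as in the error message: 3958 rays on 8 devices.
rays = {
    "origins": np.zeros((3958, 3)),
    "directions": np.ones((3958, 3)),
}
sharded, padding = pad_and_shard(rays, device_count=8)
restored = unshard_and_unpad(sharded["origins"], padding)
print(sharded["origins"].shape, padding, restored.shape)
```

With these sizes, two rays are appended so that each of the 8 devices receives 495 rays. In a real render loop the model outputs for the padded rays would have to be discarded in the same way before the image is assembled.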