SangbumChoi opened this issue 2 years ago
Hi, while running the demo in the Jupyter notebook, a shape error occurs when the code reaches the rendering step.
Any clue how to fix this error?
```
9 frames
/usr/local/lib/python3.7/dist-packages/hypernerf/evaluation.py in render_image(state, rays_dict, model_fn, device_count, rng, chunk, default_ret_key)
    114         lambda x: x[(proc_id * per_proc_rays):((proc_id + 1) * per_proc_rays)],
    115         chunk_rays_dict)
--> 116     chunk_rays_dict = utils.shard(chunk_rays_dict, device_count)
    117     model_out = model_fn(key_0, key_1, state.optimizer.target['model'],
    118                          chunk_rays_dict, state.extra_params)

/usr/local/lib/python3.7/dist-packages/hypernerf/utils.py in shard(xs, device_count)
    287   if device_count is None:
    288     jax.local_device_count()
--> 289   return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
    290
    291

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    176   leaves, treedef = tree_flatten(tree, is_leaf)
    177   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 178   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    179
    180 tree_multimap = tree_map

/usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
    176   leaves, treedef = tree_flatten(tree, is_leaf)
    177   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 178   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    179
    180 tree_multimap = tree_map

/usr/local/lib/python3.7/dist-packages/hypernerf/utils.py in <lambda>(x)
    287   if device_count is None:
    288     jax.local_device_count()
--> 289   return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
    290
    291

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _reshape(a, order, *args)
   1727
   1728 def _reshape(a, *args, order="C"):
-> 1729   newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
   1730   if order == "C":
   1731     return lax.reshape(a, newshape, None)

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in _compute_newshape(a, newshape)
   1723   return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
   1724                if core.symbolic_equal_dim(d, -1) else d
-> 1725                for d in newshape)
   1726
   1727

/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py in <genexpr>(.0)
   1723   return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
   1724                if core.symbolic_equal_dim(d, -1) else d
-> 1725                for d in newshape)
   1726
   1727

/usr/local/lib/python3.7/dist-packages/jax/core.py in divide_shape_sizes(s1, s2)
   1407   s2 = s2 or (1,)
   1408   handler, ds = _dim_handler_and_canonical(*s1, *s2)
-> 1409   return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
   1410
   1411 def same_shape_sizes(s1: Shape, s2: Shape) -> bool:

/usr/local/lib/python3.7/dist-packages/jax/core.py in divide_shape_sizes(self, s1, s2)
   1322       return 1
   1323     if sz1 % sz2:
-> 1324       raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
   1325     return sz1 // sz2
   1326

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3958, 3) and (8, -1, 3)
```
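For context, the failing reshape can be reproduced in isolation. A minimal sketch, assuming the shapes from the traceback (8 local devices and a chunk of 3958 rays):

```python
import jax.numpy as jnp

# Shapes taken from the traceback: 8 local devices, a chunk of 3958 rays
# with 3 coordinates each.
device_count = 8
rays = jnp.zeros((3958, 3))

try:
    # This mirrors the reshape inside hypernerf.utils.shard.
    rays.reshape((device_count, -1) + rays.shape[1:])
except Exception as e:
    # Fails because 3958 is not divisible by 8, so the -1 dimension
    # cannot be inferred.
    print(e)
```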
If I'm guessing right, you ran the code on 8 GPUs, so `utils.shard` tries to reshape the 3958-ray chunk into (8, -1, 3) and fails because 3958 is not divisible by 8. An annoying workaround is to change the device count from 8 to 2, since 3958 divides evenly by 2. Hoping someone can give a better solution :(
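If you can't change the number of visible devices, a rough sketch of a padding-based workaround, assuming you patch or wrap `utils.shard` locally (`pad_to_multiple` and `shard_padded` below are hypothetical helpers, not part of HyperNeRF):

```python
import jax
import jax.numpy as jnp


def pad_to_multiple(x, device_count):
    """Pad the leading axis of x so its length is divisible by device_count."""
    remainder = x.shape[0] % device_count
    if remainder == 0:
        return x
    pad = device_count - remainder
    # Repeat the last row as padding; the extra outputs are discarded later.
    return jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)


def shard_padded(xs, device_count=None):
    """Like utils.shard, but pads each leaf instead of failing on the reshape."""
    if device_count is None:
        device_count = jax.local_device_count()

    def _shard(x):
        x = pad_to_multiple(x, device_count)
        return x.reshape((device_count, -1) + x.shape[1:])

    return jax.tree_map(_shard, xs)
```

After running the model on the padded chunk you would slice the padded rays back off the flattened output before assembling the image, so the extra rows never reach the final render.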