ChikaYan / d2nerf

Apache License 2.0

Error when training/rendering with 8 GPUs #7

Closed: prajwalchidananda closed this issue 1 year ago

prajwalchidananda commented 1 year ago
Traceback (most recent call last):
  File "train.py", line 574, in <module>
    app.run(main)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "train.py", line 495, in main
    process_iterator(tag='runtime_eval',
  File "train.py", line 528, in process_iterator
    model_out = render_fn(state, batch, rng=rng)
  File "/home/wayve/prajwal/d2nerf/hypernerf/evaluation.py", line 119, in render_image
    chunk_rays_dict = utils.shard(chunk_rays_dict, device_count)
  File "/home/wayve/prajwal/d2nerf/hypernerf/utils.py", line 289, in shard
    return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/wayve/prajwal/d2nerf/hypernerf/utils.py", line 289, in <lambda>
    return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1756, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1750, in _compute_newshape
    return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1750, in <genexpr>
    return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/core.py", line 1438, in divide_shape_sizes
    return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
  File "/home/wayve/prajwal/d2nerf/env/lib/python3.8/site-packages/jax/core.py", line 1348, in divide_shape_sizes
    raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (60, 3) and (8, -1, 3)
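The last frame is the reshape inside utils.shard, which splits the leading (ray) axis into (device_count, -1). That only works when the number of rays is divisible by the device count. Here the chunk holds 60 rays, and 60 = 8 × 7 + 4, so 8 equal shards are impossible. Below is a minimal reproduction. It assumes NumPy as a stand-in for jax.numpy, so NumPy raises ValueError where JAX raised InconclusiveDimensionOperation. The 60-ray array is a hypothetical input shaped like the one in the traceback.

```python
import numpy as np

def shard(x, device_count):
    # Same reshape as the lambda in hypernerf/utils.py shard():
    # split the leading (ray) axis into (device_count, rays_per_device).
    return x.reshape((device_count, -1) + x.shape[1:])

device_count = 8
last_chunk = np.zeros((60, 3))  # hypothetical final chunk: 60 rays, 3 values each

print(last_chunk.shape[0] % device_count)  # 4 rays left over after 8 equal shards
try:
    shard(last_chunk, device_count)
except ValueError as err:  # NumPy's counterpart of JAX's InconclusiveDimensionOperation
    print("cannot shard:", err)

print(shard(np.zeros((64, 3)), device_count).shape)  # 64 rays split cleanly: (8, 8, 3)
```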

I encountered this error when training on 8 GPUs (specifically, on vrig_balloon). It is raised while rendering an image during evaluation, in the runtime_eval step of train.py. Here's the corresponding PR that fixes the bug: https://github.com/d2nerf/d2nerf/pull/8
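One common way to handle a final chunk whose size isn't a multiple of the device count is to pad it before sharding and drop the padded rows after rendering. The sketch below shows that generic technique. It is not necessarily what the linked PR changes. The helper names pad_and_shard and unshard and the dict-of-arrays input are hypothetical. NumPy is used so the example runs without accelerators, and jax.numpy provides pad and reshape with the same arguments.

```python
import numpy as np

def pad_and_shard(rays, device_count):
    """Edge-pad each array's leading axis to a multiple of device_count, then shard.

    rays: dict of arrays sharing the same leading (ray) dimension (hypothetical layout).
    Returns (sharded dict, number of padding rows added).
    """
    num_rays = next(iter(rays.values())).shape[0]
    padding = (-num_rays) % device_count  # 0 when already divisible
    sharded = {}
    for key, x in rays.items():
        if padding:
            # Repeat the last ray so the padded entries are still valid model inputs.
            x = np.pad(x, [(0, padding)] + [(0, 0)] * (x.ndim - 1), mode="edge")
        sharded[key] = x.reshape((device_count, -1) + x.shape[1:])
    return sharded, padding

def unshard(outputs, padding):
    """Merge the device axis back into the ray axis and drop the padded rays."""
    merged = {}
    for key, y in outputs.items():
        y = y.reshape((-1,) + y.shape[2:])
        merged[key] = y[: y.shape[0] - padding]
    return merged

chunk = {"origins": np.random.rand(60, 3), "directions": np.random.rand(60, 3)}
sharded, padding = pad_and_shard(chunk, device_count=8)
print(padding, sharded["origins"].shape)  # 4 (8, 8, 3)
restored = unshard(sharded, padding)
print(restored["origins"].shape)  # (60, 3)
```

Edge padding (rather than zeros) is a design choice that keeps the extra rays well-formed, for example avoiding zero-length direction vectors. Their outputs are discarded in unshard anyway.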

ChikaYan commented 1 year ago

Thank you, I've merged your PR.