google-research / multinerf

A Code Release for Mip-NeRF 360, Ref-NeRF, and RawNeRF
Apache License 2.0

OOM with only ~12MB memory allocated / requested on GPU #77

Open GCChen97 opened 1 year ago

GCChen97 commented 1 year ago

Hi, I tried to run Ref-NeRF, but no matter how small the batch size is, the OOM problem arises. I am not sure whether it is a bug in JAX, multinerf, or TensorFlow. I've tried jax v0.3.24/25 and got the same problem.

I used:

Python 3.9
Ubuntu 20.04.5
RTX 3090
CUDA 11.6, cuDNN 8.6
NVIDIA driver 510
jax v0.3.24/25
flax v0.6.1/2
2022-11-29 21:04:38.780955: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.00MiB (rounded to 6291456)requested by op 
2022-11-29 21:04:39.275116: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:492] ****************************************************************************************************
2022-11-29 21:04:39.275246: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6291456 bytes.
2022-11-29 21:04:39.275246: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6291456 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    6.00MiB
              constant allocation:         0B
        maybe_live_out allocation:    6.00MiB
     preallocated temp allocation:         0B
                 total allocation:   12.00MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 6.00MiB
                Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=0]" source_file="/media/gccrcv/Data/Opensources/multinerf/internal/models.py" source_line=689
                XLA Label: concatenate
                Shape: f32[4096,128,3]
                ==========================

        Buffer 2:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 3:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 4:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 5:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 6:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 7:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 8:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 9:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 10:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 11:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 12:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 13:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 14:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 15:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

Traceback (most recent call last):
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/media/gccrcv/Data/Opensources/multinerf/train.py", line 288, in <module>
    app.run(main)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/media/gccrcv/Data/Opensources/multinerf/train.py", line 229, in main
    rendering = models.render_image(
  File "/media/gccrcv/Data/Opensources/multinerf/internal/models.py", line 689, in render_image
    jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/tree_util.py", line 207, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/tree_util.py", line 207, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/media/gccrcv/Data/Opensources/multinerf/internal/models.py", line 689, in <lambda>
    jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1791, in concatenate
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1791, in <listcomp>
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 648, in concatenate
    return concatenate_p.bind(*operands, dimension=dimension)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
    return compiled_fun(*args)
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 200, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 895, in _execute_compiled
    out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6291456 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    6.00MiB
              constant allocation:         0B
        maybe_live_out allocation:    6.00MiB
     preallocated temp allocation:         0B
                 total allocation:   12.00MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 6.00MiB
                Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=0]" source_file="/media/gccrcv/Data/Opensources/multinerf/internal/models.py" source_line=689
                XLA Label: concatenate
                Shape: f32[4096,128,3]
                ==========================

        Buffer 2:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 3:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 4:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 5:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 6:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 7:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 8:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 9:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 10:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 11:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 12:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 13:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 14:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================

        Buffer 15:
                Size: 384.0KiB
                Entry Parameter Subshape: f32[256,128,3]
                ==========================
GCChen97 commented 1 year ago

Even when I tried to limit GPU memory allocation as mentioned in some other issues, the script would still use about 90% of the memory and crash when rendering an image. I just can't understand why JAX leaves this problem alone.

export XLA_PYTHON_CLIENT_MEM_FRACTION="0.5"
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_force_compilation_parallelism=1"
jonbarron commented 1 year ago

What does nvidia-smi look like?

GCChen97 commented 1 year ago

Hi @jonbarron, this is what I am trying now.

if __name__ == '__main__':
  import os
  # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.5"
  # print("XLA_PYTHON_CLIENT_MEM_FRACTION=0.5")
  os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
  import tensorflow as tf
  gpus = tf.config.list_physical_devices('GPU')
  if gpus:
    # Restrict TensorFlow to only allocate 2GB of memory on the first GPU
    try:
      tf.config.set_logical_device_configuration(
          gpus[0],
          [tf.config.LogicalDeviceConfiguration(memory_limit=2048)])
      logical_gpus = tf.config.list_logical_devices('GPU')
      print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
      # Virtual devices must be set before GPUs have been initialized
      print(e)
  with gin.config_scope('train'):
    app.run(main)
      2/250000: loss=0.03680, psnr=14.793, lr=3.21e-05 | data=0.03646, orie=1.2e-05, pred=0.00033, 167 r/s
    100/250000: loss=0.08784, psnr=11.866, lr=6.17e-04 | data=0.07865, orie=0.00886, pred=0.00033, 2898 r/s
Rendering chunk 0/129599
Rendering chunk 12960/129599
Rendering chunk 25920/129599
Rendering chunk 38880/129599
Rendering chunk 51840/129599
Rendering chunk 64800/129599
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| 54%   64C    P2   220W / 350W |  11815MiB / 24576MiB |     54%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1181      G   /usr/lib/xorg/Xorg                 53MiB |
|    0   N/A  N/A      1757      G   /usr/lib/xorg/Xorg                307MiB |
|    0   N/A  N/A      1893      G   /usr/bin/gnome-shell               50MiB |
|    0   N/A  N/A      5143      G   ...nlogin/bin/sunloginclient       12MiB |
|    0   N/A  N/A     32718      G   /usr/lib/firefox/firefox          117MiB |
|    0   N/A  N/A   1223976      G   ...RendererForSitePerProcess      171MiB |
|    0   N/A  N/A   1277408      C   python                          11081MiB |
+-----------------------------------------------------------------------------+
GCChen97 commented 1 year ago

The OOM still arose.

Rendering chunk 0/129599
Rendering chunk 12960/129599
Rendering chunk 25920/129599
Rendering chunk 38880/129599
Rendering chunk 51840/129599
Rendering chunk 64800/129599
Rendering chunk 77760/129599
Rendering chunk 90720/129599
Rendering chunk 103680/129599
Rendering chunk 116640/129599
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| 52%   57C    P2   120W / 350W |  23919MiB / 24576MiB |      7%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1181      G   /usr/lib/xorg/Xorg                 53MiB |
|    0   N/A  N/A      1757      G   /usr/lib/xorg/Xorg                307MiB |
|    0   N/A  N/A      1893      G   /usr/bin/gnome-shell               50MiB |
|    0   N/A  N/A      5143      G   ...nlogin/bin/sunloginclient       12MiB |
|    0   N/A  N/A     32718      G   /usr/lib/firefox/firefox          117MiB |
|    0   N/A  N/A   1223976      G   ...RendererForSitePerProcess      161MiB |
|    0   N/A  N/A   1277408      C   python                          23195MiB |
+-----------------------------------------------------------------------------+
GCChen97 commented 1 year ago

I've tried different memory-limit configurations. It seems the OOM has nothing to do with the memory limit, because it eventually crashes anyway.

  import os
  os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.7"
  print("XLA_PYTHON_CLIENT_MEM_FRACTION=0.7")
  # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
  import tensorflow as tf
  gpus = tf.config.list_physical_devices('GPU')
  # if gpus:
  #   # Restrict TensorFlow to only allocate 4GB of memory on the first GPU
  #   try:
  #     tf.config.set_logical_device_configuration(
  #         gpus[0],
  #         [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])
  #     logical_gpus = tf.config.list_logical_devices('GPU')
  #     print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  #   except RuntimeError as e:
  #     # Virtual devices must be set before GPUs have been initialized
  #     print(e)

  if gpus:
    try:
      # Currently, memory growth needs to be the same across GPUs
      for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
      logical_gpus = tf.config.list_logical_devices('GPU')
      print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
      # Memory growth must be set before GPUs have been initialized
      print(e)

  with gin.config_scope('train'):
    app.run(main)
GCChen97 commented 1 year ago

I bypassed the OOM by simply transferring the chunks in render_image to the CPU.

Riga27527 commented 1 year ago

> I bypassed the OOM by simply transferring the chunks in render_image to the CPU.

Hi @GCChen97, I have encountered the same problem as you. Could you share your code for this step? Thanks!

GCChen97 commented 1 year ago

Hi @Riga27527, here is my solution:

def move_chunks_to_cpu(chunks):
  # Move every per-chunk output from the GPU to host memory so that the
  # final concatenation in render_image does not allocate on the GPU.
  chunks_new = []
  device_cpu = jax.devices("cpu")[0]
  for chunk in chunks:
    chunk_new = {}
    chunk_new["acc"] = jax.device_put(chunk["acc"], device_cpu)
    chunk_new["distance_mean"] = \
      jax.device_put(chunk["distance_mean"], device_cpu)
    chunk_new["distance_median"] = \
      jax.device_put(chunk["distance_median"], device_cpu)
    chunk_new["distance_percentile_5"] = \
      jax.device_put(chunk["distance_percentile_5"], device_cpu)
    chunk_new["distance_percentile_95"] = \
      jax.device_put(chunk["distance_percentile_95"], device_cpu)
    chunk_new["normals"] = \
      jax.device_put(chunk["normals"], device_cpu)
    chunk_new["normals_pred"] = \
      jax.device_put(chunk["normals_pred"], device_cpu)

    chunk_new["ray_rgbs"] = [ jax.device_put(data, device_cpu)
      for data in chunk["ray_rgbs"] ]
    chunk_new["ray_sdist"] = [ jax.device_put(data, device_cpu)
      for data in chunk["ray_sdist"] ]
    chunk_new["ray_weights"] = [ jax.device_put(data, device_cpu)
      for data in chunk["ray_weights"] ]

    chunk_new["rgb"] = jax.device_put(chunk["rgb"], device_cpu)
    chunk_new["roughness"] = jax.device_put(chunk["roughness"], device_cpu)

    chunks_new.append(chunk_new)
  return chunks_new

def render_image(...
  ...
  chunks = move_chunks_to_cpu(chunks)
  # Concatenate all chunks within each leaf of a single pytree.
  rendering = (
      jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))
  ...
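A more compact variant of the same idea (a minimal, untested sketch) is to let jax.tree_util.tree_map move every leaf of each chunk to the CPU, rather than listing the keys by hand:

import jax

def move_chunks_to_cpu(chunks):
  # Put every array leaf of every chunk on the host so that the later
  # jnp.concatenate in render_image runs on the CPU instead of the GPU.
  device_cpu = jax.devices("cpu")[0]
  to_cpu = lambda x: jax.device_put(x, device_cpu)
  return [jax.tree_util.tree_map(to_cpu, chunk) for chunk in chunks]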
levelcapgg commented 1 year ago

> Hi @Riga27527, here is my solution: […]

Which file(s) did you modify? train.py? It looks like render_image is defined in internal/models.py, so I tried adding your code there, but I ended up getting a 'tree' parameter related error:

"line 686, in render_image jax.tree_util.tree_map(lambda args: jnp.concatenate(args), chunks)) TypeError: tree_map() missing 1 required positional argument: 'tree'"

I have been stuck trying to get this to work for several days. I originally tried Windows, then WSL Ubuntu, and now finally dual-boot Ubuntu, with several different versions of CUDA and cuDNN on each system. Error after error... I am finally able to start training, but now I'm getting OOM. I have an RTX 4090 and reduced the batch size to 4096, but at iteration 5000 it starts rendering chunks, and at the end of the chunks I get the same error.

Riga27527 commented 1 year ago

> Hi @Riga27527, here is my solution: […]
>
> Which file(s) did you modify? […]

I didn't use the above code, I just reduced the 'render_chunk_size' in internal/config.py to 2048.
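A minimal sketch of the same change without editing the source, assuming render_chunk_size is a field of the gin-configurable Config (so it can be overridden in the .gin file passed to train.py):

# Appended to the .gin config passed via --gin_configs; assumes
# render_chunk_size is exposed on the gin-configurable Config.
Config.render_chunk_size = 2048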

levelcapgg commented 1 year ago

> Hi @Riga27527, here is my solution: […]
>
> Which file(s) did you modify? […]
>
> I didn't use the above code, I just reduced the 'render_chunk_size' in internal/config.py to 2048.

Thank you. I lowered render_chunk_size and tested it; now I am getting this error:

Profiling failure on cuDNN engine eng28{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED in external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc(4640): 'status'

Another thread mentions that metric_harness is the problem and suggests commenting it out, but doesn't say where. I am testing commenting out the following in train.py, lines 242-248:

    #metric = metric_harness(
    #    postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb))
    #print(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s')
    #for name, val in metric.items():
    #  if not np.isnan(val):
    #    print(f'{name} = {val:.4f}')
    #    summary_writer.scalar('train_metrics/' + name, val, step)

Also line 80 in train.py:

metric_harness = image.MetricHarness()

So far this is the first time I have trained past iteration 5000. I will provide an update later.
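A gentler alternative to deleting those lines (just a sketch, not something verified in this thread) would be to guard the metric computation so a cuDNN failure inside metric_harness does not stop training; the names below follow the train.py snippet above:

    # Hypothetical guard around the metric block in train.py; if metric
    # computation fails (e.g. CUDNN_STATUS_ALLOC_FAILED), skip it this step.
    try:
      metric = metric_harness(
          postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb))
      for name, val in metric.items():
        if not np.isnan(val):
          print(f'{name} = {val:.4f}')
          summary_writer.scalar('train_metrics/' + name, val, step)
    except Exception as e:
      print(f'Skipping train metrics this step: {e}')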

GCChen97 commented 1 year ago

> Hi @Riga27527, here is my solution: […]
>
> Which file(s) did you modify? train.py? It looks like render_image is defined in internal/models.py, so I tried adding your code there, but I ended up getting a 'tree' parameter related error. […]

Yes, the code goes in internal/models.py. The key is that the chunks are moved to the CPU before the concatenation, as the modified code does.