GCChen97 opened 1 year ago
Even though I tried to limit the GPU RAM allocation as mentioned in some other issues, the script still used ~90% of the RAM and crashed when rendering an image. I just can't understand why JAX leaves this problem alone.
export XLA_PYTHON_CLIENT_MEM_FRACTION="0.5"
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_force_compilation_parallelism=1"
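For reference, the same flags can also be set from inside Python instead of the shell, as long as this happens before JAX is first imported anywhere in the process. A minimal launcher sketch (the values here mirror the exports above, not anything from the repo):

```python
# Sketch: set the XLA memory flags from Python. These are read when JAX
# initializes its backend, so they must be set BEFORE `import jax`.
import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"   # cap JAX at 50% of VRAM
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_strict_conv_algorithm_picker=false "
    "--xla_gpu_force_compilation_parallelism=1")

import jax  # must come after the environment is configured
print(jax.devices())
```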
What does nvidia-smi look like?
Hi, @jonbarron , this is what I am trying now.
if __name__ == '__main__':
  import os
  # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
  # print("XLA_PYTHON_CLIENT_MEM_FRACTION=0.5")
  os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

  import tensorflow as tf
  gpus = tf.config.list_physical_devices('GPU')
  if gpus:
    # Restrict TensorFlow to only allocate 2GB of memory on the first GPU
    try:
      tf.config.set_logical_device_configuration(
          gpus[0],
          [tf.config.LogicalDeviceConfiguration(memory_limit=2048)])
      logical_gpus = tf.config.list_logical_devices('GPU')
      print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
      # Virtual devices must be set before GPUs have been initialized
      print(e)

  with gin.config_scope('train'):
    app.run(main)
2/250000: loss=0.03680, psnr=14.793, lr=3.21e-05 | data=0.03646, orie=1.2e-05, pred=0.00033, 167 r/s
100/250000: loss=0.08784, psnr=11.866, lr=6.17e-04 | data=0.07865, orie=0.00886, pred=0.00033, 2898 r/s
Rendering chunk 0/129599
Rendering chunk 12960/129599
Rendering chunk 25920/129599
Rendering chunk 38880/129599
Rendering chunk 51840/129599
Rendering chunk 64800/129599
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02 Driver Version: 510.85.02 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 On | N/A |
| 54% 64C P2 220W / 350W | 11815MiB / 24576MiB | 54% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1181 G /usr/lib/xorg/Xorg 53MiB |
| 0 N/A N/A 1757 G /usr/lib/xorg/Xorg 307MiB |
| 0 N/A N/A 1893 G /usr/bin/gnome-shell 50MiB |
| 0 N/A N/A 5143 G ...nlogin/bin/sunloginclient 12MiB |
| 0 N/A N/A 32718 G /usr/lib/firefox/firefox 117MiB |
| 0 N/A N/A 1223976 G ...RendererForSitePerProcess 171MiB |
| 0 N/A N/A 1277408 C python 11081MiB |
+-----------------------------------------------------------------------------+
The OOM still arose.
Rendering chunk 0/129599
Rendering chunk 12960/129599
Rendering chunk 25920/129599
Rendering chunk 38880/129599
Rendering chunk 51840/129599
Rendering chunk 64800/129599
Rendering chunk 77760/129599
Rendering chunk 90720/129599
Rendering chunk 103680/129599
Rendering chunk 116640/129599
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02 Driver Version: 510.85.02 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 On | N/A |
| 52% 57C P2 120W / 350W | 23919MiB / 24576MiB | 7% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1181 G /usr/lib/xorg/Xorg 53MiB |
| 0 N/A N/A 1757 G /usr/lib/xorg/Xorg 307MiB |
| 0 N/A N/A 1893 G /usr/bin/gnome-shell 50MiB |
| 0 N/A N/A 5143 G ...nlogin/bin/sunloginclient 12MiB |
| 0 N/A N/A 32718 G /usr/lib/firefox/firefox 117MiB |
| 0 N/A N/A 1223976 G ...RendererForSitePerProcess 161MiB |
| 0 N/A N/A 1277408 C python 23195MiB |
+-----------------------------------------------------------------------------+
I've tried different configurations of the RAM limit. It seems the OOM has nothing to do with the RAM limit, because the script eventually crashes regardless.
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
print("XLA_PYTHON_CLIENT_MEM_FRACTION=0.7")
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
# if gpus:
#   # Restrict TensorFlow to only allocate 4GB of memory on the first GPU
#   try:
#     tf.config.set_logical_device_configuration(
#         gpus[0],
#         [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])
#     logical_gpus = tf.config.list_logical_devices('GPU')
#     print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
#   except RuntimeError as e:
#     # Virtual devices must be set before GPUs have been initialized
#     print(e)
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

with gin.config_scope('train'):
  app.run(main)
I bypassed the OOM by simply transferring the chunks in render_image to the CPU.
Hi, @GCChen97 , I have encountered the same problem as you. Could you share your code for this step? Thanks!
Hi, @Riga27527 , here is my solution:
def move_chunks_to_cpu(chunks):
  chunks_new = []
  device_cpu = jax.devices("cpu")[0]
  for chunk in chunks:
    chunk_new = {}
    chunk_new["acc"] = jax.device_put(chunk["acc"], device_cpu)
    chunk_new["distance_mean"] = \
        jax.device_put(chunk["distance_mean"], device_cpu)
    chunk_new["distance_median"] = \
        jax.device_put(chunk["distance_median"], device_cpu)
    chunk_new["distance_percentile_5"] = \
        jax.device_put(chunk["distance_percentile_5"], device_cpu)
    chunk_new["distance_percentile_95"] = \
        jax.device_put(chunk["distance_percentile_95"], device_cpu)
    chunk_new["normals"] = \
        jax.device_put(chunk["normals"], device_cpu)
    chunk_new["normals_pred"] = \
        jax.device_put(chunk["normals_pred"], device_cpu)
    chunk_new["ray_rgbs"] = [jax.device_put(data, device_cpu)
                             for data in chunk["ray_rgbs"]]
    chunk_new["ray_sdist"] = [jax.device_put(data, device_cpu)
                              for data in chunk["ray_sdist"]]
    chunk_new["ray_weights"] = [jax.device_put(data, device_cpu)
                                for data in chunk["ray_weights"]]
    chunk_new["rgb"] = jax.device_put(chunk["rgb"], device_cpu)
    chunk_new["roughness"] = jax.device_put(chunk["roughness"], device_cpu)
    chunks_new.append(chunk_new)
  return chunks_new
def render_image(...
  ...
  chunks = move_chunks_to_cpu(chunks)
  # Concatenate all chunks within each leaf of a single pytree.
  rendering = (
      jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))
  ...
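The per-key copying above can also be written generically: since the chunk dicts are pytrees, a single tree_map applies device_put to every array leaf, nested lists included. A sketch of the idea (the helper name tree_to_cpu is mine, not from the repo):

```python
# Generic variant of move_chunks_to_cpu: tree_map visits every array
# leaf of a pytree (nested dicts / lists), so one device_put per leaf
# moves the whole structure to host memory. `tree_to_cpu` is a
# hypothetical name, not part of multinerf.
import jax

def tree_to_cpu(tree):
    cpu = jax.devices("cpu")[0]
    return jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), tree)
```

With this, `chunks = move_chunks_to_cpu(chunks)` becomes `chunks = tree_to_cpu(chunks)`, and any output key added later is handled automatically.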
Which file(s) did you modify? train.py? It looks like render_image is referenced from internal/model.py, so I tried to add your code there, but I ended up getting a 'tree' parameter related error:
"line 686, in render_image jax.tree_util.tree_map(lambda args: jnp.concatenate(args), chunks)) TypeError: tree_map() missing 1 required positional argument: 'tree'"
I have been stuck trying to get this to work for several days. Originally I tried Windows, then WSL Ubuntu, and now finally dual-boot Ubuntu, with several different versions of CUDA and cuDNN on each system. Error after error... I finally was able to start training, and now I'm getting OOM. I have an RTX 4090; I reduced the batch size to 4096, but at 5000 iterations it starts rendering chunks, and at the end of a chunk I get the same error.
I didn't use the above code; I just reduced 'render_chunk_size' in internal/config.py to 2048.
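If you'd rather not edit the source, the same override can normally go in the gin config file passed to train.py, assuming the standard multinerf Config fields (values here are illustrative; tune them to your VRAM):

```gin
# Hypothetical excerpt of a scene .gin file; smaller chunks trade
# render speed for lower peak GPU memory during render_image.
Config.batch_size = 4096
Config.render_chunk_size = 2048
```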
Thank you, I lowered the render_chunk_size and tested... now I am getting this error:
Profiling failure on cuDNN engine eng28{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED in external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc(4640): 'status'
Another thread mentions that metric_harness is the problem and suggests commenting it out, but doesn't say where. I am testing commenting out the following in train.py, lines 242-248:
# metric = metric_harness(
#     postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb))
# print(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s')
# for name, val in metric.items():
#   if not np.isnan(val):
#     print(f'{name} = {val:.4f}')
#     summary_writer.scalar('train_metrics/' + name, val, step)
Also line 80 in train.py:
So far this is the first time I have trained past 5000 iter. I will provide an update later.
Yes, the code is for internal/models.py. The key is that the chunks need to be moved to the CPU, as the modified code does.
Hi, I tried to run Ref-NeRF, but no matter how small the batch size is, the OOM problem arises. I am not sure whether it is a bug in JAX, multinerf, or TF. I've tried jax v0.3.24/25 but got the same problem.
I used: