Open BianFeiHu opened 2 years ago
It looks like the OOM is for visualizing the normals of the rendered depth map, which is probably something you don't need. I'd just delete that line and not visualize that component.
Thanks for your reply. I deleted that line, but now I hit another OOM problem when testing:
ssim = ssim_fn(pred_color, test_case['pixels'])
Traceback (most recent call last):
File "/home/feihu/.conda/envs/metanerf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/feihu/.conda/envs/metanerf/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/data/feihu/mipnerf-main/train.py", line 321, in <module>
app.run(main)
File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/data/feihu/mipnerf-main/train.py", line 300, in main
ssim = ssim_fn(pred_color, test_case['pixels'])
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[3,1,800,790]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,1,800,800]{3,2,1,0} %bitcast.3, f32[1,1,1,11]{3,2,1,0} %bitcast.5), window={size=1x11}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)) feature_group_count=1 batch_group_count=1 lhs_shape=(3, 1, 800, 800) rhs_shape=(1, 1, 1, 11) precision=(<Precision.HIGHEST: 2>, <Precision.HIGHEST: 2>) preferred_element_type=None]" source_file="/data/feihu/mipnerf-main/internal/math.py" source_line=93}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
Original error: UNIMPLEMENTED: DNN library is not found.
I need the SSIM score, so I can't just delete this line. Is there a better solution?
Hi, I face the same issue. Did you find a solution? @BianFeiHu
Same here
Hi! I managed to solve this problem by the method below.
https://github.com/google/jax/issues/4920#issuecomment-1161227026
Hi, I am training on an RTX 3080, and it crashes every 5000 iterations when executing this line:
vis_suite = vis.visualize_suite(pred_distance, pred_acc)
And here is the error message:
I have found that JAX shows this message when it runs out of memory, so I changed my batch_size from 1024 to 512, but training still takes 10 GB. How can I reduce GPU memory usage?
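One thing that may explain why halving batch_size did not change the reported usage: by default JAX preallocates a large fraction of GPU memory up front, so the number shown by nvidia-smi does not track what the model actually needs. Below is a minimal sketch, assuming the crash really is memory related, of setting the documented `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables before JAX is imported. The helper name `configure_jax_gpu_memory` is hypothetical and not part of mipnerf or JAX.

```python
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Set XLA client memory options. Call this before the first `import jax`.

    Hypothetical helper for illustration; only the environment variable
    names come from the JAX GPU memory allocation documentation.
    """
    # "false" makes JAX allocate GPU memory on demand instead of up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"

    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        # Fraction of total GPU memory to preallocate; this setting
        # applies when preallocation is enabled.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)


# Option A: allocate on demand, leaving headroom for cuDNN workspaces.
configure_jax_gpu_memory(preallocate=False)

# Option B (alternative): keep preallocation but cap it at half the GPU.
# configure_jax_gpu_memory(preallocate=True, mem_fraction=0.5)

# import jax  # import only after the environment has been configured
```

The same variables can be set in the shell when launching train.py instead of in Python. Whether this removes the crash at `vis.visualize_suite` or `ssim_fn` depends on how much memory the full-image evaluation needs, so treat it as something to try rather than a guaranteed fix.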