google / mipnerf

Apache License 2.0

Reducing batch_size does not reduce GPU memory usage #26

Status: Open. BianFeiHu opened this issue 2 years ago.

BianFeiHu commented 2 years ago

Hi, I am training on an RTX 3080, and every 5000 iterations the run crashes while executing vis_suite = vis.visualize_suite(pred_distance, pred_acc). Here is the error message:

Traceback (most recent call last):
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/feihu/mipnerf-main/train.py", line 321, in <module>
    app.run(main)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/data/feihu/mipnerf-main/train.py", line 295, in main
    vis_suite = vis.visualize_suite(pred_distance, pred_acc)
  File "/data/feihu/mipnerf-main/internal/vis.py", line 140, in visualize_suite
    'depth_normals': visualize_normals(depth, acc)
  File "/data/feihu/mipnerf-main/internal/vis.py", line 125, in visualize_normals
    normals = depth_to_normals(scaled_depth)
  File "/data/feihu/mipnerf-main/internal/vis.py", line 38, in depth_to_normals
    dy = convolve2d(depth, f_blur[None, :] * f_edge[:, None])
  File "/data/feihu/mipnerf-main/internal/vis.py", line 30, in convolve2d
    return jsp.signal.convolve2d(
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/scipy/signal.py", line 85, in convolve2d
    return _convolve_nd(in1, in2, mode, precision=precision)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/scipy/signal.py", line 65, in _convolve_nd
    result = lax.conv_general_dilated(in1[None, None], in2[None, None], strides,
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/lax/convolution.py", line 147, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/core.py", line 326, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/core.py", line 675, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 98, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/util.py", line 219, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/util.py", line 212, in cached
    return f(*args, **kwargs)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 148, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 704, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 806, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 768, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 713, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[1,1,800,800]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,800,800]{3,2,1,0} %Arg_0.1, f32[1,1,3,3]{3,2,1,0} %Arg_1.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 1, 800, 800) rhs_shape=(1, 1, 3, 3) precision=(<Precision.HIGHEST: 2>, <Precision.HIGHEST: 2>) preferred_element_type=None]" source_file="/data/feihu/mipnerf-main/internal/vis.py" source_line=30}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

I have found that JAX shows this message when it runs out of GPU memory (OOM), so I reduced batch_size from 1024 to 512, but training still uses 10 GB. How can I reduce GPU memory usage?
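A plausible explanation for the unchanged 10 GB (an assumption, not confirmed in this thread): by default JAX preallocates a large fraction of GPU memory when its GPU backend starts, so nvidia-smi reports roughly the same footprint regardless of batch_size. JAX's documented environment variables XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION control this behaviour. The sketch below sets them; `configure_jax_gpu_memory` is a hypothetical helper, and it has to run before `import jax` (for example at the very top of train.py). It cannot shrink what the computation genuinely needs, so the cuDNN failure may persist if the visualization itself does not fit.

```python
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Set JAX's GPU allocator environment variables (hypothetical helper).

    Must run before `import jax`; once the GPU backend has been
    initialized, changing these variables may have no effect.
    """
    if mem_fraction is not None and not 0.0 < mem_fraction <= 1.0:
        raise ValueError("mem_fraction must be in (0, 1]")
    # "false" stops JAX from reserving a large block of GPU memory up front;
    # memory is then allocated on demand.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # With preallocation enabled, reserve this fraction of total GPU
        # memory instead of JAX's default fraction.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    # Return the allocator-related settings now in effect, for logging.
    return {k: v for k, v in os.environ.items() if k.startswith("XLA_PYTHON_CLIENT_")}
```

For example, call configure_jax_gpu_memory() and only then import jax. Turning preallocation off trades the upfront reservation for on-demand allocation, which JAX's documentation notes is more prone to memory fragmentation, so a run close to the memory limit can still OOM.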

jonbarron commented 2 years ago

It looks like the OOM is for visualizing the normals of the rendered depth map, which is probably something you don't need. I'd just delete that line and not visualize that component.
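If you would rather keep the depth-normals visualization than delete it, one option is to move that small convolution off the GPU. The HLO in the first traceback shows an 800x800 depth map convolved with a 3x3 kernel and 1-pixel padding on each side, producing an 800x800 ('same'-size) output, which is cheap to compute on the CPU. The NumPy sketch below is a hypothetical host-side stand-in for such a call; it assumes a 2-D image, an odd-sized kernel and zero padding, and it has not been checked against mip-NeRF's vis.py.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def convolve2d_same_cpu(image, kernel):
    """2-D convolution with zero padding and 'same' output size, on the CPU.

    Hypothetical stand-in for a small GPU convolution such as the 3x3 one
    used when visualizing depth normals. Assumes odd kernel dimensions.
    """
    image = np.asarray(image, dtype=np.float64)
    kernel = np.asarray(kernel, dtype=np.float64)
    if image.ndim != 2 or kernel.ndim != 2:
        raise ValueError("image and kernel must both be 2-D")
    kh, kw = kernel.shape
    if kh % 2 == 0 or kw % 2 == 0:
        raise ValueError("kernel dimensions must be odd")
    # Zero-pad so every output pixel has a full kh x kw neighbourhood.
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    # A true convolution flips the kernel; without the flip this would be
    # cross-correlation.
    flipped = kernel[::-1, ::-1]
    # windows[i, j] is the kh x kw patch centred on output pixel (i, j).
    windows = sliding_window_view(padded, (kh, kw))
    return np.einsum("ijkl,kl->ij", windows, flipped)
```

For example, convolve2d_same_cpu(np.asarray(depth), kernel) first copies the depth map to host memory. If SciPy is available, scipy.signal.convolve2d(image, kernel, mode='same') should give the same result and can serve as a cross-check.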

BianFeiHu commented 2 years ago

Thanks for your reply. I deleted that line, but now I hit another OOM error during testing, at ssim = ssim_fn(pred_color, test_case['pixels']):

Traceback (most recent call last):
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/feihu/mipnerf-main/train.py", line 321, in <module>
    app.run(main)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/feihu/.conda/envs/metanerf/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/data/feihu/mipnerf-main/train.py", line 300, in main
    ssim = ssim_fn(pred_color, test_case['pixels'])
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[3,1,800,790]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,1,800,800]{3,2,1,0} %bitcast.3, f32[1,1,1,11]{3,2,1,0} %bitcast.5), window={size=1x11}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)) feature_group_count=1 batch_group_count=1 lhs_shape=(3, 1, 800, 800) rhs_shape=(1, 1, 1, 11) precision=(<Precision.HIGHEST: 2>, <Precision.HIGHEST: 2>) preferred_element_type=None]" source_file="/data/feihu/mipnerf-main/internal/math.py" source_line=93}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

I need the SSIM score, so I can't just delete this line. Is there a better solution?
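One way to keep the SSIM metric without touching the GPU is to compute it on the host. The second traceback shows a 1x11 filter applied in 'valid' mode (800 columns in, 790 out), which is what a separable 11-tap window produces. The NumPy sketch below implements the standard Gaussian-window SSIM under assumed defaults (11-tap Gaussian, sigma 1.5, k1 = 0.01, k2 = 0.03, images scaled to [0, 1]); these values are assumptions rather than values read from mip-NeRF's internal/math.py, so compare against the original ssim_fn on a few images before relying on the numbers.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def _filter_valid(x, kernel, axis):
    """Filter `x` with a 1-D `kernel` along `axis`, keeping only 'valid' outputs."""
    # sliding_window_view appends the window axis last; contracting it with
    # the kernel applies the filter. The Gaussian is symmetric, so
    # correlation and convolution coincide.
    windows = sliding_window_view(x, kernel.size, axis=axis)
    return windows @ kernel


def ssim_cpu(img0, img1, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    """Mean SSIM of two HxWxC images, computed with NumPy on the host (sketch)."""
    img0 = np.asarray(img0, dtype=np.float64)
    img1 = np.asarray(img1, dtype=np.float64)
    # Normalised 1-D Gaussian; the 2-D window is its outer product (separable).
    coords = np.arange(filter_size) - (filter_size - 1) / 2
    kernel = np.exp(-0.5 * (coords / filter_sigma) ** 2)
    kernel /= kernel.sum()

    def blur(z):
        # Filter rows, then columns: output is (H - 10, W - 10, C) for 11 taps.
        return _filter_valid(_filter_valid(z, kernel, axis=0), kernel, axis=1)

    mu0, mu1 = blur(img0), blur(img1)
    # Local variances and covariance; clip tiny negative variances from rounding.
    sigma00 = np.maximum(blur(img0 ** 2) - mu0 ** 2, 0.0)
    sigma11 = np.maximum(blur(img1 ** 2) - mu1 ** 2, 0.0)
    sigma01 = blur(img0 * img1) - mu0 * mu1

    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    numer = (2 * mu0 * mu1 + c1) * (2 * sigma01 + c2)
    denom = (mu0 ** 2 + mu1 ** 2 + c1) * (sigma00 + sigma11 + c2)
    return float(np.mean(numer / denom))
```

In train.py this would look something like ssim = ssim_cpu(np.asarray(pred_color), np.asarray(test_case['pixels'])); np.asarray copies each JAX array to host memory, so no GPU convolution is involved.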

hdzmtsssw commented 2 years ago

Hi, I'm facing the same issue. Did you find a solution, @BianFeiHu?

colinpeng-datascience commented 2 years ago

Same here

colinpeng-datascience commented 2 years ago

Hi! I managed to solve this problem using the method in this comment:

https://github.com/google/jax/issues/4920#issuecomment-1161227026
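Separately from whatever the linked comment recommends (its contents are not reproduced in this thread), the XLA error output posted earlier names its own fallback: XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. The sketch below sets it from Python, appending to any flags already present instead of overwriting them. `add_xla_flag` is a hypothetical helper; the variable must be set before JAX initializes, and if the GPU is genuinely out of memory this may only move the failure elsewhere.

```python
import os


def add_xla_flag(flag):
    """Append `flag` to XLA_FLAGS unless it is already there (hypothetical helper)."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    if flag not in existing:
        existing.append(flag)
    # XLA reads XLA_FLAGS as a space-separated list of flags.
    os.environ["XLA_FLAGS"] = " ".join(existing)
    return os.environ["XLA_FLAGS"]


# Fallback suggested by the XLA error message itself; run before `import jax`.
add_xla_flag("--xla_gpu_strict_conv_algorithm_picker=false")
```

Setting the variable in the shell (XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false python -m train ...) is equivalent; the Python version is only convenient when other flags are already configured.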