facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Very slow rendering for large mesh #1759

Open andrewcaunes opened 4 months ago

andrewcaunes commented 4 months ago

Hello, thank you for your work!

Problem: I am trying to render many 1024x2048 images from a large mesh with ~5M vertices and ~10M faces. Rendering is very slow (>30 secs per image), whereas I think it should take well under 1 sec. The mesh is large (~400 MB), but I know it can be much faster than this, since I can easily visualize it and render it in real time with Open3D. I do not need differentiability here. I have read all the other issues mentioning slow rendering and tried their solutions, but they did not solve my problem.

Code I ran: the following code simply loads my mesh, creates a rasterizer with optimized settings and runs the rasterization. I do not include the shading step, since I don't need it yet and rasterization is the step that seems to be slow. The rasterizer(...) call alone takes more than 3 secs on my RTX 3070.

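import logging

import numpy as np
import torch
from pytorch3d.io import IO
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    MeshRasterizer,
    RasterizationSettings,
)


def run_project_3D_to_2D(mesh, cam_poses_path):
    with torch.no_grad():  # no gradients needed, only the rasterization output
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
            torch.cuda.set_device(device)
        else:
            device = torch.device("cpu")

        # Get mesh: either a path to a .ply file or an existing Meshes object
        if isinstance(mesh, str):
            logging.info("Trying to read %s as a mesh.", mesh)
            if not mesh.endswith(".ply"):
                raise ValueError(f"Expected a .ply file, got {mesh}")
            mesh = IO().load_mesh(path=mesh, device=device)
        else:
            mesh = mesh.to(device)
        meshes = mesh.extend(1)

        # Create cameras from the first pose; PyTorch3D expects R with shape
        # (N, 3, 3) and T with shape (N, 3), hence the added batch dimension
        cam_poses = np.load(cam_poses_path)
        R = torch.as_tensor(cam_poses[0, :3, :3], dtype=torch.float32)[None]
        T = torch.as_tensor(cam_poses[0, :3, 3], dtype=torch.float32)[None]
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

        # Create rasterizer (hard rasterization, nearest face only)
        logging.info("Creating rasterizer.")
        raster_settings = RasterizationSettings(
            image_size=(1024, 2048),
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,  # None lets PyTorch3D choose a bin size heuristically
        )
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

        # Rasterize
        logging.info("Rendering.")
        fragments = rasterizer(meshes, cameras=cameras)
        return fragments.pix_to_face
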
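A note on measuring the timings above: CUDA kernels are launched asynchronously, so timing rasterizer(meshes, cameras=cameras) with a plain wall clock can charge GPU work to whichever later call happens to block, and the first call also pays one-off setup costs. A minimal sketch of a helper that synchronizes around each run and keeps the best of several repeats (timed is a hypothetical name, not a PyTorch3D API; passing torch.cuda.synchronize as sync is the assumed usage):

```python
import time
from typing import Callable, Optional, Tuple, TypeVar

R = TypeVar("R")


def timed(
    fn: Callable[[], R],
    sync: Optional[Callable[[], None]] = None,
    repeats: int = 3,
) -> Tuple[R, float]:
    """Run fn() `repeats` times; return (last result, best wall time in seconds).

    `sync` (e.g. torch.cuda.synchronize) is called before starting the clock and
    after each run, so asynchronous GPU work is counted in the measurement.
    Taking the best run discounts one-off warm-up costs of the first call.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    best = float("inf")
    result = None
    for _ in range(repeats):
        if sync is not None:
            sync()  # make sure earlier GPU work is not billed to this run
        start = time.perf_counter()
        result = fn()
        if sync is not None:
            sync()  # wait for the kernels launched by fn() to finish
        best = min(best, time.perf_counter() - start)
    return result, best


# Hypothetical usage with the code above:
# fragments, seconds = timed(lambda: rasterizer(meshes, cameras=cameras),
#                            sync=torch.cuda.synchronize)
```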
bottler commented 4 months ago

I think your best option in PyTorch3D would be to try MeshRasterizerOpenGL, which should be much faster.
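A minimal sketch of swapping it into the code above, assuming PyTorch3D was installed with its optional OpenGL dependencies (pycuda and PyOpenGL with EGL) and that the cameras are FoVPerspectiveCameras as in the original snippet. make_opengl_rasterizer is a hypothetical helper; the imports sit inside it so the file can still be loaded on machines without EGL:

```python
def make_opengl_rasterizer(cameras, image_size=(1024, 2048)):
    """Build a MeshRasterizerOpenGL for hard, single-face-per-pixel rasterization.

    Assumes pytorch3d's OpenGL extras (pycuda, PyOpenGL with EGL) are installed
    and that `cameras` is a FoVPerspectiveCameras instance.
    """
    # Imported lazily so this module can be imported where EGL is unavailable.
    from pytorch3d.renderer import RasterizationSettings
    from pytorch3d.renderer.opengl import MeshRasterizerOpenGL

    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,    # hard rasterization: only pix_to_face is needed
        faces_per_pixel=1,  # one face per pixel, as in the original settings
    )
    return MeshRasterizerOpenGL(cameras=cameras, raster_settings=raster_settings)


# Hypothetical usage, replacing MeshRasterizer in the code above:
# rasterizer = make_opengl_rasterizer(cameras)
# fragments = rasterizer(meshes, cameras=cameras)
# pix_to_face = fragments.pix_to_face
```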

andrewcaunes commented 4 months ago

Thank you for the answer. I tried MeshRasterizerOpenGL but got this error:

Traceback (most recent call last):
  File "/home/andrew/jean-zay/projects/ia4markings/scripts/seg_3D_by_2D/proj_pytorch3D.py", line 191, in <module>
    main(args)
  File "/home/andrew/jean-zay/projects/ia4markings/scripts/seg_3D_by_2D/proj_pytorch3D.py", line 49, in main
    run_project_3D_to_2D(mesh=args.input_mesh_path, 
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/jean-zay/projects/ia4markings/scripts/seg_3D_by_2D/proj_pytorch3D.py", line 143, in run_project_3D_to_2D
    fragments = rasterizer(meshes, cameras=cameras)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/rasterizer_opengl.py", line 205, in forward
    pix_to_face, bary_coords, zbuf = self.opengl_machinery(
                                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/rasterizer_opengl.py", line 292, in __call__
    pix_to_face, bary_coord, zbuf = self._rasterize_mesh(
                                    ^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/rasterizer_opengl.py", line 443, in _rasterize_mesh
    _torch_to_opengl(face_verts, self.cuda_context, self.cuda_buffer)
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 439, in _torch_to_opengl
    cuda_copy(False)
pycuda._driver.LogicError: cuMemcpy2DUnaligned failed: invalid argument
PyCUDA WARNING: a clean-up operation failed (dead context maybe?)
cuGraphicsUnmapResources failed: invalid OpenGL or DirectX context
PyCUDA WARNING: a clean-up operation failed (dead context maybe?)
cuGraphicsUnregisterResource failed: invalid OpenGL or DirectX context
-------------------------------------------------------------------
PyCUDA ERROR: The context stack was not empty upon module cleanup.
-------------------------------------------------------------------
A context was still active when the context stack was being
cleaned up. At this point in our execution, CUDA may already
have been deinitialized, so there is no way we can finish
cleanly. The program will be aborted now.
Use Context.pop() to avoid this problem.
-------------------------------------------------------------------
Aborted (core dumped)

Have you ever run into this error when using OpenGL?

bottler commented 4 months ago

I'm sorry, I haven't, and I don't know enough to help. Can you run the PyTorch3D unit tests and see whether the OpenGL ones pass?

andrewcaunes commented 4 months ago

Thank you very much for the help!

I ran python3 -m unittest tests.test_opengl_utils -v and got:

test_cuda_context (tests.test_opengl_utils.TestDeviceContextStore.test_cuda_context) ... here
ERROR
test_egl_context (tests.test_opengl_utils.TestDeviceContextStore.test_egl_context) ... ERROR
test_multiple_renders_single_gpu_context_store (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_multiple_renders_single_gpu_context_store) ... /home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py:44: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /home/conda/feedstock_root/build_artifacts/libtorch_1706726118919/work/torch/csrc/utils/tensor_new.cpp:1508.)
  image = torch.frombuffer(out_buffer, dtype=torch.uint8).reshape(
ok
test_multiple_renders_single_gpu_single_context (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_multiple_renders_single_gpu_single_context) ... ok
test_render_multi_thread_multi_gpu (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_render_multi_thread_multi_gpu) ... In thread 0, device 0.
In thread 1, device 0.
In thread 2, device 0.
In thread 3, device 0.
In thread 4, device 0.
In thread 5, device 0.
In thread 6, device 0.
In thread 7, device 0.
In thread 8, device 0.
In thread 9, device 0.
ok
test_render_two_threads_single_gpu (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_render_two_threads_single_gpu) ... ok
test_render_two_threads_single_gpu_context_store (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_render_two_threads_single_gpu_context_store) ... ok
test_render_two_threads_two_gpus (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_render_two_threads_two_gpus) ... Exception in thread Thread-16 (_draw_squares_with_context):
Traceback (most recent call last):
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 53, in _draw_squares_with_context
    context = EGLContext(MAX_EGL_WIDTH, MAX_EGL_HEIGHT, cuda_device_id)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 187, in __init__
    self.device = _get_cuda_device(self.cuda_device_id)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 121, in _get_cuda_device
    raise ValueError(
ValueError: Found 4 CUDA devices, but none with CUDA id 1.
ERROR
test_render_two_threads_two_gpus_context_store (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_render_two_threads_two_gpus_context_store) ... Exception in thread Thread-18 (_draw_squares_with_context_store):
Traceback (most recent call last):
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 72, in _draw_squares_with_context_store
    context = global_device_context_store.get_egl_context(device)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 360, in get_egl_context
    self._egl_contexts[cuda_device_id] = EGLContext(
                                         ^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 187, in __init__
    self.device = _get_cuda_device(self.cuda_device_id)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 121, in _get_cuda_device
    raise ValueError(
ValueError: Found 4 CUDA devices, but none with CUDA id 1.
ERROR
test_draw_square (tests.test_opengl_utils.TestOpenGLSingleThreaded.test_draw_square) ... ok
test_render_two_squares (tests.test_opengl_utils.TestOpenGLSingleThreaded.test_render_two_squares) ... ok
test_device_context_store (tests.test_opengl_utils.TestOpenGLUtils.test_device_context_store) ... EGL could not release context on device cuda:0. This can happen if you created two contexts on the same device. Instead, you can use DeviceContextStore to use a single context per device, and EGLContext.make_(in)active_in_current_thread to (in)activate the context as needed.
ok
test_egl_release_error (tests.test_opengl_utils.TestOpenGLUtils.test_egl_release_error) ... EGL could not release context on device cuda:0. This can happen if you created two contexts on the same device. Instead, you can use DeviceContextStore to use a single context per device, and EGLContext.make_(in)active_in_current_thread to (in)activate the context as needed.
ok
test_no_egl_error (tests.test_opengl_utils.TestOpenGLUtils.test_no_egl_error) ... FAIL
test_egl_convert_to_int_array (tests.test_opengl_utils.TestUtils.test_egl_convert_to_int_array) ... ok
test_get_cuda_device (tests.test_opengl_utils.TestUtils.test_get_cuda_device) ... ok
test_load_extensions (tests.test_opengl_utils.TestUtils.test_load_extensions) ... ok

======================================================================
ERROR: test_cuda_context (tests.test_opengl_utils.TestDeviceContextStore.test_cuda_context)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 95, in test_cuda_context
    cuda_context_3 = global_device_context_store.get_cuda_context(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 339, in get_cuda_context
    self._cuda_contexts[cuda_device_id] = _init_cuda_context(cuda_device_id)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 418, in _init_cuda_context
    device = cuda.Device(device_id)
             ^^^^^^^^^^^^^^^^^^^^^^
pycuda._driver.LogicError: cuDeviceGet failed: invalid device ordinal

======================================================================
ERROR: test_egl_context (tests.test_opengl_utils.TestDeviceContextStore.test_egl_context)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 112, in test_egl_context
    egl_context_3 = global_device_context_store.get_egl_context(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 360, in get_egl_context
    self._egl_contexts[cuda_device_id] = EGLContext(
                                         ^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 187, in __init__
    self.device = _get_cuda_device(self.cuda_device_id)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/pytorch3d/renderer/opengl/opengl_utils.py", line 121, in _get_cuda_device
    raise ValueError(
ValueError: Found 4 CUDA devices, but none with CUDA id 1.

======================================================================
ERROR: test_render_two_threads_two_gpus (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_render_two_threads_two_gpus)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 206, in test_render_two_threads_two_gpus
    self._render_two_threads_two_gpus(_draw_squares_with_context)
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 284, in _render_two_threads_two_gpus
    result[0]["egl"]["context"].address, result[1]["egl"]["context"].address
                                         ~~~~~~~~~^^^^^^^
TypeError: 'NoneType' object is not subscriptable

======================================================================
ERROR: test_render_two_threads_two_gpus_context_store (tests.test_opengl_utils.TestOpenGLMultiThreaded.test_render_two_threads_two_gpus_context_store)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 209, in test_render_two_threads_two_gpus_context_store
    self._render_two_threads_two_gpus(_draw_squares_with_context_store)
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 284, in _render_two_threads_two_gpus
    result[0]["egl"]["context"].address, result[1]["egl"]["context"].address
                                         ~~~~~~~~~^^^^^^^
TypeError: 'NoneType' object is not subscriptable

======================================================================
FAIL: test_no_egl_error (tests.test_opengl_utils.TestOpenGLUtils.test_no_egl_error)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/andrew/jean-zay/projects/ia4markings/pytorch3d_test/tests/test_opengl_utils.py", line 377, in test_no_egl_error
    self.assertFalse(_can_import_egl_and_pycuda())
AssertionError: True is not false

----------------------------------------------------------------------
Ran 17 tests in 1.369s

FAILED (failures=1, errors=4)

Ignoring the multi-GPU tests, I guess the only failure that might be relevant is test_no_egl_error, but I'm not sure what to make of this result yet.

bottler commented 4 months ago

I'm not so worried about test_no_egl_error. Can you try tests.test_rasterizer please?
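For context, the traceback above shows test_no_egl_error asserting assertFalse(_can_import_egl_and_pycuda()), and it failed because that call returned True, i.e. EGL and pycuda were importable in this environment. A rough stdlib-only way to check which of these packages are installed; opengl_deps_installed is a hypothetical helper, and since it only locates the modules without importing them, it is weaker than the library's own check:

```python
import importlib.util
from typing import Dict


def opengl_deps_installed() -> Dict[str, bool]:
    """Report whether the packages the OpenGL rasterizer relies on can be found.

    Only locates the modules; it creates no CUDA or EGL context, so True here
    does not guarantee that OpenGL rendering will actually work.
    """
    status = {}
    for name in ("pycuda", "OpenGL.EGL"):
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # For dotted names find_spec imports the parent package first;
            # a missing parent raises ModuleNotFoundError (an ImportError).
            status[name] = False
    return status


# Hypothetical usage:
# print(opengl_deps_installed())
```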

andrewcaunes commented 4 months ago

Sure, here is what I got using python3 -m unittest tests.test_rasterizer -v:

test_compare_rasterizers (tests.test_rasterizer.TestMeshRasterizer.test_compare_rasterizers) ... An exception occurred in telemetry logging.Disabling telemetry to prevent further exceptions.
Traceback (most recent call last):
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/iopath/common/file_io.py", line 946, in __log_tmetry_keys
    handler.log_event()
  File "/home/andrew/miniconda3/envs/ia4mk_3/lib/python3.11/site-packages/iopath/common/event_logger.py", line 97, in log_event
    del self._evt
        ^^^^^^^^^
AttributeError: 'NativePathHandler' object has no attribute '_evt'
Time to transform:  0.000804901123046875
Time to set up:  0.004029273986816406
Time to rasterize:  0.0030531883239746094
ok
test_simple_sphere (tests.test_rasterizer.TestMeshRasterizer.test_simple_sphere) ... Time to transform:  0.0007436275482177734
Time to set up:  2.3603439331054688e-05
Time to rasterize:  0.0004773139953613281
Time to transform:  0.0010900497436523438
Time to set up:  3.600120544433594e-05
Time to rasterize:  0.0012364387512207031
Time to transform:  0.0008120536804199219
Time to set up:  3.6716461181640625e-05
Time to rasterize:  0.0005502700805664062
Time to transform:  0.0009360313415527344
Time to set up:  3.528594970703125e-05
Time to rasterize:  0.0005829334259033203
ok
test_simple_sphere_fisheye (tests.test_rasterizer.TestMeshRasterizer.test_simple_sphere_fisheye) ... Time to transform:  0.0007088184356689453
Time to set up:  1.0967254638671875e-05
Time to rasterize:  0.0004017353057861328
Time to transform:  0.015677690505981445
Time to set up:  1.2874603271484375e-05
Time to rasterize:  0.0004665851593017578
Time to transform:  0.001974821090698242
Time to set up:  1.0013580322265625e-05
Time to rasterize:  0.0008959770202636719
ok
test_simple_sphere_opengl (tests.test_rasterizer.TestMeshRasterizer.test_simple_sphere_opengl) ... ok
test_simple_to (tests.test_rasterizer.TestMeshRasterizer.test_simple_to) ... ok
test_check_cameras (tests.test_rasterizer.TestMeshRasterizerOpenGLUtils.test_check_cameras) ... ok
test_check_raster_settings (tests.test_rasterizer.TestMeshRasterizerOpenGLUtils.test_check_raster_settings) ... ok
test_convert_meshes_to_gl_ndc_square_img (tests.test_rasterizer.TestMeshRasterizerOpenGLUtils.test_convert_meshes_to_gl_ndc_square_img) ... ok
test_parse_and_verify_image_size (tests.test_rasterizer.TestMeshRasterizerOpenGLUtils.test_parse_and_verify_image_size) ... ok
test_simple_sphere (tests.test_rasterizer.TestPointRasterizer.test_simple_sphere) ... ok
test_simple_sphere_fisheye_against_perspective (tests.test_rasterizer.TestPointRasterizer.test_simple_sphere_fisheye_against_perspective) ... ok
test_simple_to (tests.test_rasterizer.TestPointRasterizer.test_simple_to) ... ok

----------------------------------------------------------------------
Ran 12 tests in 11.609s

OK
PyCUDA WARNING: a clean-up operation failed (dead context maybe?)
cuGraphicsUnregisterResource failed: invalid OpenGL or DirectX context

By the way, in the meantime I tried running my code in a separate file with the cow mesh for comparison, and it ran without error, so my mesh is what triggers the error. I'm guessing this is due to its large size? Is there a standard way to split a large mesh into parts that can be batched in PyTorch3D?

Thanks again for the help

bottler commented 4 months ago

You can take a piece of a mesh with Meshes.submeshes: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/structures/meshes.py#L1550
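A minimal sketch of splitting a single mesh into face chunks with submeshes and merging the per-chunk rasterizations back by depth. It assumes (not confirmed in this thread) that Meshes.submeshes accepts one list of LongTensors of face indices per mesh in the batch, returns one sub-mesh per tensor, and keeps the given face order, so a chunk-local face index j maps back to chunk_start + j. face_chunks, split_mesh and merge_chunk_fragments are hypothetical helpers; background pixels are assumed to have pix_to_face == -1:

```python
from typing import List, Sequence

import numpy as np


def face_chunks(num_faces: int, max_faces: int) -> List[range]:
    """Consecutive, non-overlapping index ranges covering range(num_faces)."""
    if max_faces <= 0:
        raise ValueError("max_faces must be positive")
    return [range(s, min(s + max_faces, num_faces))
            for s in range(0, num_faces, max_faces)]


def split_mesh(mesh, max_faces: int = 2_000_000):
    """Split a single-mesh Meshes into sub-meshes; also return chunk offsets."""
    import torch  # lazy import: the other helpers work without torch

    chunks = face_chunks(mesh.faces_packed().shape[0], max_faces)
    idx = [torch.arange(r.start, r.stop, device=mesh.device) for r in chunks]
    # One inner list per mesh in the batch (here a single mesh)
    return mesh.submeshes([idx]), [r.start for r in chunks]


def merge_chunk_fragments(pix_to_face, zbuf, offsets: Sequence[int]):
    """Keep, per pixel, the nearest hit across K separately rasterized chunks.

    pix_to_face, zbuf: (K, H, W) arrays (faces_per_pixel=1, last axis dropped).
    offsets: first original face index of each chunk.
    Returns (H, W) pix_to_face in original face indices, and (H, W) zbuf.
    """
    p2f = np.asarray(pix_to_face)
    z = np.asarray(zbuf, dtype=float)
    hit = p2f >= 0
    depth = np.where(hit, z, np.inf)               # ignore background pixels
    nearest = np.argmin(depth, axis=0)[None]       # (1, H, W) closest chunk
    offs = np.asarray(offsets).reshape(-1, 1, 1)
    global_p2f = np.where(hit, p2f + offs, -1)     # chunk-local -> original
    merged_p2f = np.take_along_axis(global_p2f, nearest, axis=0)[0]
    merged_z = np.take_along_axis(z, nearest, axis=0)[0]
    return merged_p2f, merged_z


# Hypothetical usage with a rasterizer from the code above:
# chunks, offsets = split_mesh(mesh)
# frags = [rasterizer(chunks[k], cameras=cameras) for k in range(len(chunks))]
# p2f = np.stack([f.pix_to_face[0, ..., 0].cpu().numpy() for f in frags])
# z = np.stack([f.zbuf[0, ..., 0].cpu().numpy() for f in frags])
# pix_to_face, zbuf = merge_chunk_fragments(p2f, z, offsets)
```

Rasterizing chunks one at a time keeps each buffer handed to the rasterizer smaller, at the cost of one pass per chunk.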

I wonder if the problem isn't the size but the layout of the vertices? The error is cuMemcpy2DUnaligned failed: invalid argument on the output of verts_packed(). Can you check that meshes.verts_packed() is actually on the GPU? And maybe see whether it is contiguous?
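Whether size or layout is to blame, it helps to know how large the copied buffer is. The traceback shows _torch_to_opengl being called on face_verts; assuming (not confirmed in this thread) that it holds three float32 xyz vertices per face, a quick estimate (face_verts_nbytes is a hypothetical helper):

```python
def face_verts_nbytes(num_faces: int, coords_per_vert: int = 3,
                      bytes_per_coord: int = 4) -> int:
    """Bytes in a (num_faces, 3, coords_per_vert) buffer of float32 values.

    Assumes three vertices per face and 4 bytes per coordinate (float32).
    """
    return num_faces * 3 * coords_per_vert * bytes_per_coord


# For the ~10M-face mesh in this issue:
# face_verts_nbytes(10_000_000) == 360_000_000 bytes (about 343 MiB),
# roughly six times the ~60 MB of the 5M-vertex array itself (5M * 3 * 4 bytes).
```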

andrewcaunes commented 4 months ago

I added these lines to _torch_to_opengl:

def _torch_to_opengl(torch_tensor, cuda_context, cuda_buffer):
    import torch
    print("type(torch_tensor):", type(torch_tensor))
    print("isinstance(torch_tensor, torch.Tensor):", isinstance(torch_tensor, torch.Tensor))
    print("torch_tensor.device:", torch_tensor.device)
    print("is contiguous:", torch_tensor.is_contiguous())

and got

type(torch_tensor): <class 'torch.Tensor'>
isinstance(torch_tensor, torch.Tensor): True
torch_tensor.device: cuda:0
is contiguous: True

I also generated a new mesh similar to my previous one but smaller, and the code ran successfully, so it seems that the larger size does trigger the error.
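To turn "larger meshes fail" into a concrete threshold, one option is to bisect on the number of faces. A minimal sketch, assuming a hypothetical predicate works(n) that rasterizes the first n faces (for instance via Meshes.submeshes) and reports success, and that failures are monotone in n. Because the failure above leaves PyCUDA's context dead and aborts the process, each trial is probably safest in a fresh subprocess:

```python
from typing import Callable


def largest_working_size(works: Callable[[int], bool], lo: int, hi: int) -> int:
    """Return the largest n in [lo, hi] for which works(n) is True.

    Assumes works is monotone (True up to some threshold, False beyond it)
    and that works(lo) is True. Makes O(log(hi - lo)) calls to works.
    """
    if not works(lo):
        raise ValueError("works(lo) must be True")
    while lo < hi:
        mid = (lo + hi + 1) // 2  # round up so the loop always makes progress
        if works(mid):
            lo = mid
        else:
            hi = mid - 1
    return lo


# Hypothetical usage, where try_render_in_subprocess(n) returns True on success:
# threshold = largest_working_size(try_render_in_subprocess, 100_000, 10_000_000)
```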