mitsuba-renderer / mitsuba3

Mitsuba 3: A Retargetable Forward and Inverse Renderer
https://www.mitsuba-renderer.org/

Segmentation fault in loss.backward() when defining a new integrator #1198

Open kyleleey opened 5 months ago

kyleleey commented 5 months ago

Summary

loss.backward() crashes with a segmentation fault when the scene is rendered with a newly created integrator and the loss is computed on the resulting image.

System configuration

System information:

OS: Ubuntu 20.04.6 LTS
CPU: AMD EPYC 9334 32-Core Processor
GPU: NVIDIA L40S
Python: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0]
NVidia driver: 550.54.14
LLVM: 12.0.0

Dr.Jit: 0.4.4
Mitsuba: 3.5.0
Is custom build? False
Compiled with: GNU 10.2.1
Variants: scalar_rgb scalar_spectral cuda_ad_rgb llvm_ad_rgb

Description

Hi,

I'm using Mitsuba 3 with PyTorch for object pose estimation. I noticed in the official tutorial that, for differentiable simulation, the integrator has to be "direct_projective". However, when I compute a loss on an image rendered with a newly created integrator, loss.backward() crashes with a segmentation fault and no other message. (With the default "path" integrator the gradient easily becomes NaN; I assume this is why the "direct_projective" integrator is suggested.)

Steps to reproduce

Minimal reproduction:


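import torch
import torch.nn as nn
import drjit as dr
import mitsuba as mi

# Assumed variant; 'llvm_ad_rgb' is the other AD variant listed in the system information.
mi.set_variant('cuda_ad_rgb')

@dr.wrap_ad(source='torch', target='drjit')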
def render_scene_w_pose(
    scene,
    trans=None,
    cam=None,
    spp=256,
    seed=1
):
    sensor = cam
    integrator = mi.load_dict(
        {
            'type': 'direct_projective',
        }
    )

    num_shapes = trans.shape[0]

    params = mi.traverse(scene)
    for i in range(num_shapes):
        initial_vertex_positions = dr.unravel(mi.Point3f, params[f'{i:02d}.vertex_positions'])
        translate = mi.Point3f(trans[i,0].array, trans[i,1].array, trans[i,2].array)
        trafo = mi.Transform4f.translate(translate)
        params[f'{i:02d}.vertex_positions'] = dr.ravel(trafo @ initial_vertex_positions)
    params.update()

    # With the scene's default 'path' integrator, loss.backward() does not segfault, but the gradient is NaN.
    image = mi.render(scene, params=params, sensor=sensor, integrator=integrator, spp=spp, seed=seed, seed_grad=seed+1)
    # image = mi.render(scene, params=params, sensor=sensor, spp=spp, seed=seed, seed_grad=seed+1)
    return image, scene

'''
defined elsewhere:
num_shapes
LR=1e-3
'''

trans_x_params = nn.Parameter(torch.ones(num_shapes, 1, device=device) * 0.01, requires_grad=True)
trans_y_params = nn.Parameter(torch.ones(num_shapes, 1, device=device) * 0.01, requires_grad=True)
trans_z_params = nn.Parameter(torch.ones(num_shapes, 1, device=device) * 0.01, requires_grad=True)

optimize_list = [
    {'name': 'trans_x', 'params': list([trans_x_params]), 'lr': LR},
    {'name': 'trans_y', 'params': list([trans_y_params]), 'lr': LR},
    {'name': 'trans_z', 'params': list([trans_z_params]), 'lr': LR},
]

optimizer = torch.optim.Adam(optimize_list, betas=(0.9, 0.99), eps=1e-15)

optimizer.zero_grad()
trans_pred = torch.cat([trans_x_params, trans_y_params, trans_z_params], dim=-1)
render_img, scene = render_scene_w_pose(scene, cam=render_cam, trans=trans_pred, spp=64, seed=0)

# other PyTorch operations omitted
loss = render_img.mean()
loss.backward()

njroussel commented 5 months ago

Hi @kyleleey

Sorry, I'm a bit confused by what you mean by "new integrator". Are we only talking about direct_projective, or do you have your own custom integrator?

Can you try moving the integrator construction outside of the function? Sometimes its lifetime isn't well defined.
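
A minimal sketch of that suggestion, reusing the names from the reproduction above: the integrator is built once, outside the wrapped render function, and a reference to it is handed back to the caller so it stays alive through loss.backward(). The factory name make_render_fn and the choice to pass the mitsuba and drjit modules in as arguments are assumptions made for illustration, not part of the original code, and this thread does not confirm whether the change resolves the crash.

```python
def make_render_fn(mi, dr, integrator_type='direct_projective'):
    """Create the integrator once and return a render function that reuses it.

    `mi` and `dr` are the already imported `mitsuba` and `drjit` modules,
    with a variant selected beforehand via `mi.set_variant(...)`.
    (Hypothetical helper, not a Mitsuba API.)
    """
    # Built outside the wrapped function, so it is no longer a local of the
    # forward pass. The caller holds the returned reference until after
    # loss.backward() has run.
    integrator = mi.load_dict({'type': integrator_type})

    @dr.wrap_ad(source='torch', target='drjit')
    def render_scene_w_pose(scene, trans=None, cam=None, spp=256, seed=1):
        params = mi.traverse(scene)
        for i in range(trans.shape[0]):
            key = f'{i:02d}.vertex_positions'
            positions = dr.unravel(mi.Point3f, params[key])
            offset = mi.Point3f(trans[i, 0].array, trans[i, 1].array, trans[i, 2].array)
            params[key] = dr.ravel(mi.Transform4f.translate(offset) @ positions)
        params.update()
        image = mi.render(scene, params=params, sensor=cam, integrator=integrator,
                          spp=spp, seed=seed, seed_grad=seed + 1)
        return image, scene

    return render_scene_w_pose, integrator
```

It would be called once before the optimization loop, for example render, integrator = make_render_fn(mi, dr), and then render(scene, cam=render_cam, trans=trans_pred, spp=64, seed=0) in each iteration, keeping integrator in scope until the backward pass has finished.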