mitsuba-renderer / mitsuba3

Mitsuba 3: A Retargetable Forward and Inverse Renderer

Is using a PyTorch network in mi.Loop allowed? #1332

Open brabbitdousha opened 1 month ago

brabbitdousha commented 1 month ago


Hi, I am using PyTorch with Mitsuba 3, and I need to run PyTorch network inference during rendering. Here is pseudocode of my workflow: I call trainer.eval() (which wraps a PyTorch network) inside an mi.Loop. For example, for neural importance sampling I need to query a network during rendering, and after rendering finishes I update the network using trainer.train().

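import drjit as dr
import mitsuba as mi

mi.set_variant("cuda_ad_rgb")  # assumed variant; a variant must be set before using mi.* types

class MyPathIntegrator(mi.SamplingIntegrator):
    """Simple path tracer with MIS + NEE."""

    def path_tracing(self,
                     scene: mi.Scene,
                     sampler: mi.Sampler,
                     ray: mi.Ray3f,
                     medium: mi.Medium = None,
                     active: bool = True):

        # --------------------- Configure loop state ----------------------
        # (initialization of depth, cur_depth, L, β, η, prev_si, prev_bsdf_pdf,
        #  prev_bsdf_delta, pts and albedo omitted in this pseudocode)

        loop = mi.Loop(name="Custom Path Tracer",
                       state=lambda: (sampler, ray, depth, cur_depth, L, β, η, active,
                                      prev_si, prev_bsdf_pdf, prev_bsdf_delta))

        while loop(active):
            # ... intersection, emitter sampling (NEE) and BSDF sampling (MIS) omitted ...

            # Normalize the current points to the scene bounding box and query
            # the PyTorch network (trainer is defined elsewhere)
            test_pts = (pts - scene.bbox().min) / (scene.bbox().max - scene.bbox().min)
            output = trainer.eval(test_pts)

            cur_depth += 1

        return (L, dr.neq(depth, 0), [pts, albedo, output])

    def sample(self,
               scene: mi.Scene,
               sampler: mi.Sampler,
               ray: mi.Ray3f,
               medium: mi.Medium = None,
               active: bool = True):
        (color, mask, aov) = self.path_tracing(scene, sampler, ray, medium, active)
        return (color, mask, aov)

def run_render(scene_path, spp, h, w, device):
    # Register new integrator (must happen before the scene that uses it is loaded)
    mi.register_integrator("mypath", lambda props: MyPathIntegrator(props))

    scene = mi.load_file(scene_path)
    # Render
    with dr.suspend_grad():
        for i in range(16):
            # mi.render() returns a single image tensor (AOVs become extra channels)
            img = mi.render(scene, spp=spp)
            # ... after rendering, the network is updated with trainer.train() ...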


However, after the network is updated by trainer.train(), the output of trainer.eval() inside mi.Loop does not change. I tried the following:

```
#dr.set_flag(dr.JitFlag.VCallRecord, False)
dr.set_flag(dr.JitFlag.LoopRecord, False)
```

With this, everything works correctly, but rendering is much slower.

So, with these two flags enabled, is using a PyTorch network inside mi.Loop not allowed? I only run network inference inside mi.Loop, and I am not using differentiable rendering.

System configuration

System information:

OS: Windows
CPU: Intel i9-13900H
GPU: RTX 4060 Laptop
Python version: 3.9
CUDA version: 12.0
NVIDIA driver: 550.54.14

Dr.Jit version: 0.4.4
Mitsuba version: 3.5.0

merlinND commented 1 month ago

Hello @brabbitdousha,

The goal of mi.Loop() is to create a loop inside the kernel being recorded. By definition, the body of that loop must be part of that same kernel. Calling functions from another framework, which cannot be traced by Dr.Jit and therefore cannot be included in the body of the loop, is not supported.
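One way to picture why the network output can look frozen is a toy tracer in plain Python. This is an analogy under simplifying assumptions, not how Dr.Jit is implemented, and every name in it (`Sym`, `record`, `foreign_eval`, `weights`) is hypothetical: the loop body is run once on a symbolic placeholder to record it, so an opaque call such as a PyTorch forward pass executes at record time and its result is captured as a constant in the recorded program.

```python
# Toy analogy (NOT Dr.Jit's API): "recording" runs the body once on a symbolic
# placeholder. Traceable operations become part of the recorded program; an
# opaque foreign call runs immediately and its result is baked in as a constant.

class Sym:
    """Symbolic placeholder for a loop variable during recording."""

    def __init__(self, fn):
        self.fn = fn  # maps a concrete input x to this placeholder's value

    def __add__(self, other):
        if isinstance(other, Sym):
            return Sym(lambda x: self.fn(x) + other.fn(x))
        # A plain number is captured by value at record time
        return Sym(lambda x: self.fn(x) + other)


def record(body):
    """Run `body` once on a symbolic input and return a replayable program."""
    return body(Sym(lambda x: x)).fn


weights = {"w": 1.0}  # hypothetical stand-in for the network's parameters


def foreign_eval():
    # Stands in for an untraceable call (e.g. a PyTorch forward pass): the
    # tracer cannot see inside it, so it simply returns a concrete number now.
    return weights["w"]


def body(x):
    return x + foreign_eval()


program = record(body)  # foreign_eval() runs here, exactly once
print(program(10.0))    # 11.0
weights["w"] = 5.0      # "train" the network
print(program(10.0))    # 11.0: the old output was baked into the recording
print(body(10.0))       # 15.0: eager (unrecorded) evaluation sees the update
```

In this toy, either re-recording the body or evaluating it eagerly picks up the new weights, which loosely mirrors why disabling loop recording made the results correct in the report above.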

As you have noticed, disabling loop recording (dr.JitFlag.LoopRecord) makes it possible to call into the other framework inside the loop body, but it breaks up the megakernel and incurs a lot of overhead (e.g., reading/writing the results of each kernel from/to global memory).
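The size of that overhead can be estimated with a rough cost model. The sketch below rests on hypothetical assumptions, not measurements: a recorded loop keeps its state in registers inside one kernel, while an unrecorded loop launches at least one kernel per iteration and round-trips the full loop state through global memory each time. It ignores Python re-tracing and PyTorch's own kernels, which add further cost in practice; the function names and workload numbers are illustrative only.

```python
# Rough cost model (assumptions, not measurements) comparing a recorded loop
# (one megakernel) with an unrecorded loop (one kernel per iteration).

def recorded_loop_cost(n_lanes: int, state_bytes: int, iters: int) -> dict:
    # Assumed: the whole loop runs in one kernel and the loop state stays in
    # registers, so no intermediate state goes through global memory.
    # (Parameters are unused here; kept for a signature symmetric with the other model.)
    return {"kernel_launches": 1, "state_traffic_bytes": 0}


def unrecorded_loop_cost(n_lanes: int, state_bytes: int, iters: int) -> dict:
    # Assumed: one kernel per iteration; each lane writes its state to global
    # memory at the end of an iteration and reads it back in the next one.
    return {
        "kernel_launches": iters,
        "state_traffic_bytes": 2 * n_lanes * state_bytes * iters,
    }


# Hypothetical workload: 1M paths, 64 bytes of loop state, 8 bounces.
print(recorded_loop_cost(1_000_000, 64, 8))
# {'kernel_launches': 1, 'state_traffic_bytes': 0}
print(unrecorded_loop_cost(1_000_000, 64, 8))
# {'kernel_launches': 8, 'state_traffic_bytes': 1024000000}
```

Under these assumptions the unrecorded version moves about 1 GB of intermediate state per render for this workload, traffic that the megakernel avoids entirely.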