mitsuba-renderer / mitsuba3

Mitsuba 3: A Retargetable Forward and Inverse Renderer
https://www.mitsuba-renderer.org/

Remove gc.collect() calls from RenderOp #1354

Closed: dvicini closed this 4 weeks ago

dvicini commented 4 weeks ago

I was benchmarking a very simple optimization and was surprised that it wasn't running as fast as expected. I am running on a GPU and was hoping for near-real-time performance (even with Python re-tracing).

It turns out that the explicit gc.collect() calls in the RenderOp are actually quite expensive. When optimizing the color of a wall in the Cornell box using PRB, the current master branch takes 27 s for 1024 iterations on an RTX A5000. After removing the explicit gc.collect() calls, it takes 11 s. Throughput increases from 37 iterations per second to almost 90.
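To see why a per-iteration gc.collect() can dominate an otherwise fast loop, here is a minimal, Mitsuba-independent sketch in plain Python. It assumes the cost comes from CPython's full cyclic collection, which has to traverse every tracked object on each call; the workload and object count are hypothetical placeholders, not measurements of RenderOp.

```python
import gc
import time

# Long-lived container objects tracked by CPython's cyclic garbage collector.
# A full gc.collect() traverses every tracked object, so its cost grows with
# the size of the heap. (Hypothetical stand-in for the many Python objects
# alive during an optimization; the count is arbitrary.)
long_lived = [[i] for i in range(100_000)]


def per_iteration_work():
    # Placeholder for the cheap per-iteration work of a small optimization step.
    return sum(len(str(i)) for i in range(100))


def iterations_per_second(n_iters, collect_each_iter):
    """Time a loop, optionally forcing a full collection every iteration."""
    start = time.perf_counter()
    for _ in range(n_iters):
        per_iteration_work()
        if collect_each_iter:
            gc.collect()  # full collection of all generations
    elapsed = time.perf_counter() - start
    return n_iters / elapsed


if __name__ == "__main__":
    n = 200
    print(f"with gc.collect() each iteration: {iterations_per_second(n, True):10.1f} it/s")
    print(f"without explicit gc.collect():    {iterations_per_second(n, False):10.1f} it/s")
```

The absolute numbers depend on the machine and heap size. The point is only that an explicit full collection adds a per-iteration cost that grows with the number of live objects, which matters most when each iteration is otherwise cheap, as in a small GPU optimization.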

Given this result, I think we should consider removing these gc.collect() calls from the innermost loop. As far as I know, they date back to pre-PRB, pre-Dr.Jit times, when we often ran out of memory due to the AD graph.
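A plausible reason the explicit calls helped back then can be illustrated with a small plain-Python sketch. CPython frees most objects immediately through reference counting, but objects caught in a reference cycle are only reclaimed when the cyclic collector runs. If Python objects holding large buffers (for example, parts of an AD graph) end up in such cycles, their memory lingers until the next collection, and forcing gc.collect() releases it early. The `Node` class below is a hypothetical stand-in, not a Mitsuba or Dr.Jit type.

```python
import gc
import weakref


class Node:
    """Hypothetical stand-in for an object that owns a large buffer."""

    def __init__(self, size):
        self.buffer = bytearray(size)  # memory that stays alive as long as the node does
        self.other = None


def make_cycle():
    a, b = Node(1 << 20), Node(1 << 20)
    a.other, b.other = b, a  # reference cycle: refcounts never drop to zero
    return weakref.ref(a)    # weak reference lets us observe whether `a` is freed


def cycle_needs_collector():
    was_enabled = gc.isenabled()
    gc.disable()  # keep the automatic collector from running during the demo
    try:
        ref = make_cycle()
        alive_before = ref() is not None  # still alive after all strong refs are gone
        gc.collect()                      # explicit collection reclaims the cycle
        alive_after = ref() is not None
        return alive_before, alive_after
    finally:
        if was_enabled:
            gc.enable()


if __name__ == "__main__":
    print(cycle_needs_collector())
```

Without explicit calls, CPython's automatic collector still reclaims such cycles, just on its own schedule (governed by `gc.get_threshold()`), so the trade-off is peak memory versus per-iteration cost.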

Here is the reproducer I ran:

```python
import drjit as dr
import mitsuba as mi
import tqdm

def run_optim():
    mi.set_variant('cuda_ad_rgb')
    scene = mi.cornell_box()
    scene['integrator'] = {'type': 'prb'}

    scene = mi.load_dict(scene)
    # High-quality reference image with the original parameters
    image_ref = mi.render(scene, spp=512)
    params = mi.traverse(scene)
    key = 'red.reflectance.value'
    print("Reference param ", params[key])

    # Perturb the wall color to get a starting point for the optimization
    params[key] = mi.Color3f(0.01, 0.2, 0.9)
    params.update()

    opt = mi.ad.Adam(lr=0.05)
    opt[key] = params[key]
    params.update(opt)
    for _ in tqdm.tqdm(range(1024)):
        # Low-spp differentiable render, L2 loss, and gradient step
        image = mi.render(scene, params, spp=4)
        loss = dr.mean(dr.square(image - image_ref))
        dr.backward(loss)
        opt.step()
        # Keep the reflectance in the valid [0, 1] range
        opt[key] = dr.clip(opt[key], 0.0, 1.0)
        params.update(opt)
    print("Final param ", params[key])

run_optim()
```