I was benchmarking a very simple optimization and was surprised that it ran slower than expected. I am running on the GPU and was hoping for near-realtime performance (even with Python re-tracing).
It turns out that the explicit gc.collect() calls in the RenderOp are quite expensive. When optimizing the color of a wall in the Cornell box using PRB, the current master branch takes 27 s for 1024 iterations on an RTX A5000. With the explicit gc.collect() calls removed, the same run takes 11 s. Throughput increases from 37 iterations per second to almost 90.
Given this result, I think we should consider getting rid of these gc.collect() calls in the innermost loop. As far as I know, these go back to pre-PRB, pre-Dr.Jit times when we often ran out of memory due to the AD graph.
Here is the reproducer I run:
import drjit as dr
import mitsuba as mi
import tqdm

def run_optim():
    mi.set_variant('cuda_ad_rgb')
    scene = mi.cornell_box()
    scene['integrator'] = {'type': 'prb'}
    scene = mi.load_dict(scene)
    image_ref = mi.render(scene, spp=512)
    params = mi.traverse(scene)
    key = 'red.reflectance.value'
    print("Reference param ", params[key])
    params[key] = mi.Color3f(0.01, 0.2, 0.9)
    params.update()
    opt = mi.ad.Adam(lr=0.05)
    opt[key] = params[key]
    params.update(opt)

    for _ in tqdm.tqdm(range(1024)):
        image = mi.render(scene, params, spp=4)
        loss = dr.mean(dr.square(image - image_ref))
        dr.backward(loss)
        opt.step()
        opt[key] = dr.clip(opt[key], 0.0, 1.0)
        params.update(opt)
    print("Final param ", params[key])

run_optim()
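
To get a feel for the cost of the collector on its own, independent of Mitsuba and the GPU, a minimal standalone sketch along the following lines could be used. It is not the RenderOp code: the helper names `time_loop` and `compare_gc_cost`, the dummy per-iteration work, and the number of live objects are all assumptions chosen for illustration, and only the standard-library `gc` and `time` modules are involved. The idea is that a full `gc.collect()` has to walk every tracked container object, so calling it once per iteration adds a cost that grows with the size of the Python heap rather than with the work done in the loop.

```python
import gc
import time


def time_loop(n_iters, collect_every_iter):
    """Time a dummy loop, optionally forcing a full collection per iteration."""
    start = time.perf_counter()
    acc = 0
    for i in range(n_iters):
        acc += i * i  # stand-in for the real per-iteration work (hypothetical)
        if collect_every_iter:
            gc.collect()  # explicit full collection, as in the pattern under discussion
    return time.perf_counter() - start


def compare_gc_cost(n_iters=200, n_live_objects=100_000):
    """Compare loop time with and without gc.collect() in the loop body."""
    # Keep many container objects alive so the collector has something to
    # traverse. The count is arbitrary and only mimics a non-trivial heap.
    live = [[i] for i in range(n_live_objects)]
    baseline = time_loop(n_iters, collect_every_iter=False)
    with_gc = time_loop(n_iters, collect_every_iter=True)
    del live
    return {
        "baseline_s": baseline,
        "with_gc_s": with_gc,
        "extra_per_iter_s": (with_gc - baseline) / n_iters,
    }


if __name__ == "__main__":
    result = compare_gc_cost()
    print(f"without gc.collect(): {result['baseline_s']:.4f} s")
    print(f"with gc.collect():    {result['with_gc_s']:.4f} s")
    print(f"extra per iteration:  {result['extra_per_iter_s'] * 1e3:.3f} ms")
```

The numbers this prints depend on the machine and the Python version, and they are not comparable to the timings reported above. The sketch only shows how the per-iteration overhead of an explicit collection can be measured separately from the rendering work.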