mitsuba-renderer / mitsuba3

Mitsuba 3: A Retargetable Forward and Inverse Renderer
https://www.mitsuba-renderer.org/

AD & Loop #761

Closed zichenwang01 closed 1 year ago

zichenwang01 commented 1 year ago

I ran into the following problem and do not know how to get around it:

```
drjit.Exception: loop_process_state(): one of the supplied loop state variables of type Float is attached to the AD graph (i.e., grad_enabled(..) is true). However, propagating derivatives through multiple iterations of a recorded loop is not supported (and never will be). Please see the documentation on differentiating loops for details and suggested alternatives.
```

Basically, I have my SDF grid (`sdf.grid`) as a mi.Texture3f and pass it into an optimizer:

```python
params = {'grid': sdf.grid.tensor()}
optimizer = mi.ad.Adam(lr=0.2, params=params)
```

Within each optimization iteration, I use mi.Loop for sphere tracing, and then I end up with the error above. I am also not sure how to use mi.traverse() to pass my SDF parameters into an optimizer, which is why I directly pass in a mi.TensorXf. It would be nice to have some clarification on how to do this with Mitsuba.

Thanks!

njroussel commented 1 year ago

Hi @zichenwang01

This is expected behavior: as the error suggests, this is simply not supported. In order to track derivatives across a loop, it cannot be a recorded loop. Internally, Mitsuba disables recorded loops when running the backward pass (source code), but this comes at a significant performance cost. The recommended solution is to compute the gradients with an adjoint method, like the prb-style integrators we have implemented. (Take a look at this discussion too: https://github.com/mitsuba-renderer/drjit/issues/53).
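To make the adjoint idea concrete, here is an illustrative pure-Python sketch (deliberately not Dr.Jit/Mitsuba code; `forward` and `adjoint_grad` are made-up names). The forward loop runs fully detached from any AD graph, and the derivative is accumulated in a separate pass rather than by backpropagating through the loop iterations themselves, which is the overall structure an adjoint method follows:

```python
# Pure-Python sketch (no Dr.Jit) of adjoint-style differentiation:
# the forward loop is detached, and the derivative comes from a
# second, separate pass instead of differentiating through the
# recorded loop. All names here are illustrative, not Mitsuba API.

def forward(theta, n_steps=5):
    """Detached forward pass: x_{k+1} = x_k - 0.5 * (x_k - theta)."""
    x = 0.0
    for _ in range(n_steps):
        x = x - 0.5 * (x - theta)
    return x

def adjoint_grad(theta, n_steps=5):
    """Second pass that accumulates d x_final / d theta analytically.

    Each step contributes via the chain rule:
        dx_{k+1}/dtheta = 0.5 * dx_k/dtheta + 0.5
    """
    dx = 0.0
    for _ in range(n_steps):
        dx = 0.5 * dx + 0.5
    return dx

# Finite-difference check of the adjoint result.
eps = 1e-6
fd = (forward(1.0 + eps) - forward(1.0 - eps)) / (2 * eps)
print(abs(adjoint_grad(1.0) - fd) < 1e-6)  # True: gradients agree
```

The point is that no AD state ever lives inside the loop, so nothing needs to be attached to the loop's state variables; the real prb integrators apply the same principle to the rendering equation rather than a toy recurrence.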

For SDFs specifically, you might want to take a look at this project too: https://github.com/rgl-epfl/differentiable-sdf-rendering/tree/main

As for the TensorXf/Texture3f question, this is the correct way to do it. The Texture3f is just a convenient interface around a TensorXf, which is what actually holds the data. You should therefore expose and optimize the TensorXf object through mi.traverse().