mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

Forward-mode gradients with symbolic conditionals #295

Open dvicini opened 2 weeks ago

dvicini commented 2 weeks ago

Separating this from issue #253

The following code silently produces a result of 0 instead of the expected 1. If I remove the if-statement, it works as expected. I tried adding `dr.hint(..., exclude=[b])`, but that produces an error: `RuntimeError: ad_traverse(): tried to forward-propagate derivatives across edge a1 -> a2, which lies outside of the current dr.isolate_grad() scope.`

It's not entirely clear what the right pattern is to get correct results here. I guess one option is to explicitly forward-propagate up to the point right before the if-statement is entered?
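For reference, the expected value follows directly from how the snippet builds its inputs. `dr.linspace(mi.Float, 1, 2, 16)` has spacing 1/15, and `param` (written $p$ here) is added to every entry. With the seed gradient $\dot p = 1$, forward mode should therefore yield:

$$
a_i = 1 + \tfrac{i}{15} + p, \qquad b = a_3 = \tfrac{6}{5} + p, \qquad \dot b = \frac{\partial b}{\partial p}\,\dot p = 1 \cdot 1 = 1.
$$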

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax(print_code=True)
def f():
  param = mi.Float(0.0)
  dr.enable_grad(param)
  dr.set_grad(param, 1.0)

  a = dr.linspace(mi.Float, 1, 2, 16) + param

  result = mi.Float(0.0)
  b = dr.gather(mi.Float, a, 3)
  # dr.forward_to(b)  # One option: explicitly propagate up to b before entering the if-statement
  if b == b:  # Always true
    result += dr.forward_to(b)  # Fails silently: result stays 0 instead of 1

  # Doing the same without the if-statement works as expected
  # result += dr.forward_to(b)

  return result

result = f()
print(result)
```

wjakob commented 2 weeks ago

Do you have thoughts on how you would like this to behave? (following the comment here)

dvicini commented 1 week ago

I am honestly not quite sure. Thinking about it some more, this seems quite tricky to solve robustly.

I am still trying to see how to best use the if-statements and am not sure yet what usage patterns will emerge.

Maybe it's more a matter of adding a warning about this to the documentation, if there isn't one already.