mitsuba-renderer / drjit

Dr.Jit — A Just-In-Time-Compiler for Differentiable Rendering
BSD 3-Clause "New" or "Revised" License

dr.syntax: AD gets disabled by variable use in while loop #253

Open dvicini opened 2 months ago

dvicini commented 2 months ago

I have some code that mixes AD with handwritten derivatives and loops (e.g., something similar to PRB).

It appears that within dr.syntax, any use of a differentiable variable inside a loop disables gradient tracking for that variable.

Here is an example:

import drjit as dr

@dr.syntax
def f():

  i = dr.zeros(dr.llvm.Int32, 10)
  result = dr.zeros(dr.llvm.ad.Float, 10)

  a = dr.linspace(dr.llvm.ad.Float, 0, 1, 10)
  dr.enable_grad(a)
  print(dr.grad_enabled(a))

  with dr.suspend_grad():
    while i < 5:
      result += a  # a is only read here, never assigned
      i += 1

  print(dr.grad_enabled(a))

f()

This prints

True
False

But I would have expected this to print

True
True

The current behavior is unintuitive and leads to a confusing loss of gradient tracking. It seems to me that the loop should have no influence on whether a has gradients enabled, as was the case before dr.syntax.

njroussel commented 2 months ago

Hi @dvicini

It looks like the syntax rewriting is a bit aggressive here: it thinks a is part of the loop state. This means that a gets re-assigned at the end of the loop, and hence within the suspend_grad() scope. Changing the loop to while dr.hint(i < 5, exclude=[a]): fixes the issue.

I'll look into why it thinks a should be in the state. As a general rule, @dr.syntax errs on the side of caution to guarantee that the loop is valid, but that can come at a cost, as you see here.

njroussel commented 2 months ago

Ok, there's not much we can do here, I believe.

In short, even though it's clear in your reproducer that a is not written to in the loop, we cannot guarantee that in the general case. For example, a random number generator will typically never appear on the left-hand side of an assignment (i.e., you don't write sampler = sampler.next_1d()), but it must still be part of the loop state because it can evolve implicitly (e.g., through a call to sampler.next_1d()). So in this case, we still consider a to be part of the loop state: even though it's only used on the right-hand side of an assignment, it might have evolved implicitly.
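To make this argument concrete, here is a deliberately simplified, hypothetical model of two state-detection strategies in plain Python; it is not Dr.Jit's actual rewriter. written_names() only collects assignment targets, while conservative_loop_state() collects every name the loop reads or writes, minus an optional exclude list that plays the role of dr.hint(..., exclude=[...]). ToySampler is a made-up stand-in for a sampler whose next_1d() mutates internal state.

```python
import ast
import textwrap


class ToySampler:
    """Hypothetical sampler: next_1d() advances hidden state in place."""

    def __init__(self, seed=0):
        self.state = seed

    def next_1d(self):
        # One LCG step: the object changes even though the caller
        # never writes 'sampler = ...'
        self.state = (1103515245 * self.state + 12345) % 2**31
        return self.state / 2**31


def _loop_nodes(loop_src):
    """Parse a single while loop and return its condition and body nodes."""
    loop = ast.parse(textwrap.dedent(loop_src)).body[0]
    if not isinstance(loop, ast.While):
        raise ValueError("expected a single while loop")
    return [loop.test, *loop.body]


def written_names(loop_src):
    """Names that appear as assignment targets (too optimistic)."""
    return {
        n.id
        for node in _loop_nodes(loop_src)
        for n in ast.walk(node)
        if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)
    }


def conservative_loop_state(loop_src, exclude=()):
    """Every name read or written in the loop, minus the excluded ones."""
    names = {
        n.id
        for node in _loop_nodes(loop_src)
        for n in ast.walk(node)
        if isinstance(n, ast.Name)
    }
    return names - set(exclude)


reproducer = """
while i < 5:
    result += a
    i += 1
"""

with_sampler = """
while i < 5:
    result += sampler.next_1d()
    i += 1
"""

print(sorted(written_names(reproducer)))            # ['i', 'result']
print(sorted(conservative_loop_state(reproducer)))  # ['a', 'i', 'result']
print(sorted(conservative_loop_state(reproducer, exclude=["a"])))  # ['i', 'result']
print("sampler" in written_names(with_sampler))            # False
print("sampler" in conservative_loop_state(with_sampler))  # True

# The sampler really does change without ever being re-assigned
sampler = ToySampler()
before = sampler.state
sampler.next_1d()
print(sampler.state != before)  # True
```

The optimistic strategy would miss sampler entirely and silently break the loop, while the conservative one also sweeps in a; that trade-off is why an explicit exclude list is needed.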

The usual workflow in these kinds of situations is to add print_code=True to @dr.syntax() and look at the rewritten function. It's usually fairly obvious when too many variables are included, and you should then specify them in the exclude list of a dr.hint statement.

dvicini commented 2 months ago

Fair enough, thanks for checking. I have to say I wasn't super aware of the various debug options for dr.syntax, but with print_code=True and the exclude hints, this should be okay in practice.

wjakob commented 2 months ago

The loop constructs, dr.syntax and dr.hint have extensive documentation. Please take a look and post an issue/PR if anything should be added. We should document any potential gotchas.

dvicini commented 1 month ago

I've run into another, seemingly related problem: the following code silently produces a result value of 0 instead of the expected 1. If I remove the if-statement, it works as expected. I tried adding dr.hint(..., exclude=[b]), but that produces the error RuntimeError: ad_traverse(): tried to forward-propagate derivatives across edge a1 -> a2, which lies outside of the current dr.isolate_grad() scope

import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax(print_code=True)
def f():
  param = mi.Float(0.0)
  dr.enable_grad(param)
  dr.set_grad(param, 1.0)

  a = dr.linspace(mi.Float, 1, 2, 16) + param

  result = mi.Float(0.0)
  b = dr.gather(mi.Float, a, 3)
  if b == b: # Always true
    result += dr.forward_to(b)  # Fails silently

  # Doing the same without the if-statement works as expected
  # result += dr.forward_to(b)

  return result

result = f()
print(result)

wjakob commented 1 month ago

The last issue you posted is unrelated to the problem with dr.suspend_grad and symbolic operations. Could you create a separate issue for it? It's actually not clear how this should behave in general. Suppose we have an arbitrarily nested sequence of symbolic operations, and the user calls dr.forward at the innermost level. The system would then have to travel back in time, in a sense, and forward-propagate derivatives into each of the outer scopes.

Previously, Mitsuba handled this by requiring a forward AD pass outside of the symbolic region, whose derivative values could then be picked up inside it. But of course, this isn't fully general.

dvicini commented 1 month ago

Yes, you are right. I created a separate issue to track this: #295