dvicini opened this issue 2 months ago
Hi @dvicini,

It looks like the syntax rewriting is a bit aggressive here: it thinks `a` is part of the loop state. This means that `a` gets re-assigned at the end of the loop, and hence within the `suspend_grad` scope. Changing the loop to `while dr.hint(i < 5, exclude=[a]):` fixes the issue.
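Here is a minimal sketch of what that fix looks like in context (hypothetical code based on the description above; the function name, loop bound, and accumulator are made up for illustration):

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax
def g(a: mi.Float):
    i = mi.UInt32(0)
    total = mi.Float(0.0)
    # 'exclude=[a]' tells the loop machinery that 'a' is not part of
    # the loop state, so it is never re-assigned inside the
    # suspend_grad scope and keeps its AD graph
    while dr.hint(i < 5, exclude=[a]):
        with dr.suspend_grad():
            total += a  # 'a' is only read here, never written
        i += 1
    return total
```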
I'll look into why it thinks `a` should be in the state. As a general rule, `@dr.syntax` tends to be "safer" than necessary in order to guarantee that the loop is valid, but that can come at a cost, as you see here.
Ok, there's not much we can do here, I believe.

In short, even though it's clear in your reproducer that `a` is not being written to in the loop, we cannot guarantee that in the more general case. For example, a random number generator will effectively never appear on the left-hand side of an assignment (i.e., `sampler = sampler.next_1d()`), but it must still be considered part of the loop state in order to work, because it might evolve implicitly (i.e., `sampler.next_1d()`). So, in this case, we still consider `a` to be part of the loop state: even though it only appears on the right-hand side of assignments, it might have evolved implicitly.
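To illustrate the point (a hypothetical sketch; `mi.Sampler` stands in here for any object with hidden mutable state):

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax
def g(sampler: mi.Sampler, n: int):
    i = mi.UInt32(0)
    total = mi.Float(0.0)
    while i < n:
        # 'sampler' never appears on the left-hand side of an
        # assignment, yet every call advances its internal RNG state,
        # so it must be treated as loop state regardless
        total += sampler.next_1d()
        i += 1
    return total

sampler = mi.load_dict({'type': 'independent'})
sampler.seed(0, 16)  # seed with a wavefront size of 16
print(g(sampler, 5))
```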
The workflow for this kind of situation is usually to add `print_code=True` to `@dr.syntax()` and look at the rewritten function. It is usually fairly obvious when too many variables are included, and you can then specify them in the `exclude` list of a `dr.hint` statement.
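For instance (a hypothetical sketch; `print_code` is the real decorator option, everything else is illustrative):

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax(print_code=True)  # prints the rewritten function to stdout
def g(a: mi.Float):
    i = mi.UInt32(0)
    total = mi.Float(0.0)
    while i < 5:
        total += a  # 'a' is only read, yet lands in the loop state
        i += 1
    return total

g(dr.linspace(mi.Float, 0, 1, 16))
```

The printed version shows the generated `dr.while_loop(...)` call together with its explicit state variables; if `a` appears there unexpectedly, switch the condition to `dr.hint(i < 5, exclude=[a])`.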
Fair enough, thanks for checking. I have to say I wasn't super aware of the various debug options for `dr.syntax`, but with `print_code=True` and the exclude hints, this should be okay in practice.
The loop constructs, `dr.syntax`, and `dr.hint` have extensive documentation. Please take a look and post an issue/PR if anything should be added. We should document any potential gotchas.
I've run into another, seemingly related problem: the following code silently produces a `result` value of 0 instead of 1, as it should. If I remove the if-statement, it works as expected. I tried adding `dr.hint(..., exclude=[b])`, but that produces an error: `RuntimeError: ad_traverse(): tried to forward-propagate derivatives across edge a1 -> a2, which lies outside of the current dr.isolate_grad() scope`.
```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax(print_code=True)
def f():
    param = mi.Float(0.0)
    dr.enable_grad(param)
    dr.set_grad(param, 1.0)

    a = dr.linspace(mi.Float, 1, 2, 16) + param
    result = mi.Float(0.0)
    b = dr.gather(mi.Float, a, 3)

    if b == b:  # Always true
        result += dr.forward_to(b)  # Fails silently
    # Doing the same without the if-statement works as expected
    # result += dr.forward_to(b)

    return result

result = f()
print(result)
```
The last issue you posted is unrelated to the problem with `dr.suspend_grad` and symbolic operations. Could you create a separate issue for it? It's actually not quite clear how this should behave in general. Suppose we have an arbitrarily nested sequence of symbolic operations, and the user calls `dr.forward` at the innermost level. The system then has to, in a sense, travel back in time and forward-propagate derivatives into each of the outer scopes.

The way we handled this in Mitsuba before was that you had to do a forward AD pass outside of the symbolic region, whose derivative values could then be picked up inside it. But this of course isn't fully general.
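A minimal sketch of that older pattern, applied to the reproducer above (hypothetical code, not an official recipe):

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax
def f():
    param = mi.Float(0.0)
    dr.enable_grad(param)

    a = dr.linspace(mi.Float, 1, 2, 16) + param
    b = dr.gather(mi.Float, a, 3)

    # Forward-propagate *outside* of any symbolic construct ...
    dr.forward(param)
    db = dr.grad(b)

    result = mi.Float(0.0)
    # ... so that the symbolic 'if' only reads the precomputed
    # derivative value instead of triggering AD traversal inside it
    if b == b:
        result += db
    return result

print(f())
```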
Yes, you are right. I created a separate issue to track this: #295
I have some code that mixes AD with handwritten derivatives & loops (e.g., similar to something like PRB). It appears that within `dr.syntax`, any use of a differentiable variable within a loop disables the variable's AD graph. Here is an example:
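(A hypothetical reconstruction of the kind of reproducer described; the exact code is not shown here, and all names and constants are made up. It assumes a differentiable variable `a` that is only ever read inside a `@dr.syntax` loop containing a `dr.suspend_grad()` block, matching the diagnosis in the replies above.)

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax
def f():
    a = dr.linspace(mi.Float, 0, 1, 16)
    dr.enable_grad(a)

    i = mi.UInt32(0)
    total = mi.Float(0.0)
    while i < 5:
        with dr.suspend_grad():
            total += a  # 'a' is only ever read here
        i += 1

    # 'a' was absorbed into the loop state and re-assigned inside the
    # suspend_grad scope, so its AD graph is silently lost
    print(dr.grad_enabled(a))
    return total

f()
```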
This prints a result showing that `a` has lost its AD graph, but I would have expected it to print one with gradient tracking intact.
The current behavior is a bit unintuitive and leads to a confusing loss of gradient tracking. To me, it seems that the loop should have no influence on whether `a` has gradients enabled or not, similar to the pre-`dr.syntax` behavior.