Closed aroidzap closed 6 months ago
Regarding the recursion error, recursion happens in: https://github.com/shader-slang/slang/blob/56928794d0800824dc91e150cb345b5fec24d930/source/slang/slang-ir-autodiff-transcriber-base.cpp#L152
Slang source to reproduce the issue:
[Differentiable]
float3 test(float3 val)
{
    return float3(cos(val.x), sin(val.y), exp(val.z));
}

[shader("compute")]
[numthreads(16, 16, 1)]
void test_main(uint3 tid : SV_DispatchThreadID)
{
    float3 a = float3(tid);
    float3 b = float3(0);
    DifferentialPair<DifferentialPair<float3>> dp_dp_a = diffPair(diffPair(a));
    DifferentialPair<float3> dp_b = diffPair(b);
    bwd_diff(bwd_diff(test))(dp_dp_a, dp_b);
}
Thank you for providing a minimal reproducer! This will make it much simpler to figure out what's wrong. Will provide more updates soon.
I've done some digging into this issue. It appears that there are several unhandled cases when generics and multiple reverse-mode passes interact.
The specific issue is that reverse-mode auto-diff produces function-dependent types (the intermediate context data structure depends on the function contents). These types are not handled correctly when they appear within generic containers (all stdlib methods like cos and sin are generics).
This doesn't matter for a single reverse-mode pass or reverse-mode over forward-mode. However, it becomes an issue with multiple layered bwd_diff calls.
Implementing the fixes is slightly involved, so I'll continue to work on it and provide updates (I'm moving it to Q2 2024).
In the meantime, this can be worked around by using fwd_diff for the inner call instead of bwd_diff (i.e. bwd_diff(fwd_diff(fn))) to compute Hessians more efficiently.
Note that multiple reverse-mode passes (as shown in your snippet) lead to extremely inefficient code and are usually not the recommended approach for computing Hessians.
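For reference, the suggested workaround applied to the reproducer might look like the sketch below. It is untested: it assumes the argument types from the original snippet carry over unchanged, i.e. that bwd_diff(fwd_diff(test)) accepts a nested DifferentialPair as its inout input and a DifferentialPair<float3> as the output differential.

```slang
[Differentiable]
float3 test(float3 val)
{
    return float3(cos(val.x), sin(val.y), exp(val.z));
}

[shader("compute")]
[numthreads(16, 16, 1)]
void test_main(uint3 tid : SV_DispatchThreadID)
{
    float3 a = float3(tid);
    float3 b = float3(0);
    DifferentialPair<DifferentialPair<float3>> dp_dp_a = diffPair(diffPair(a));
    DifferentialPair<float3> dp_b = diffPair(b);
    // Reverse-mode over forward-mode, replacing the reverse-over-reverse
    // call bwd_diff(bwd_diff(test)) that triggers the crash.
    // Assumption: the same argument types are accepted here.
    bwd_diff(fwd_diff(test))(dp_dp_a, dp_b);
}
```

Only the last call differs from the original reproducer, so the two variants are easy to compare side by side.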
@aroidzap Are you currently blocked by this issue? Does changing to bwd_diff(fwd_diff(f)) work for you?
Hello, thanks for asking, this issue does not block my work. I was just curious to try higher order diff and I encountered this bug. I haven't had time yet to check whether bwd_diff(fwd_diff(f)) works in my code.
Looks like this issue can be closed. @aroidzap please reopen if required.
1) For some reason, float values marked as no_diff are not treated as no_diff in higher order derivatives (i.e. a no_diff float in the forward function is also a no_diff float in a bwd_diff(func)(...) call, but is a DifferentialPair<float> in a bwd_diff(bwd_diff(func))(...) call). uint values are working as expected.
2) Usage of the built-in cos(), sin() and exp() functions in higher order diff causes the slang compiler to crash with 3221225725 (0xc00000fd) (Recursion Error). As a workaround, the Taylor expansions of these functions can be used.
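A minimal sketch of the Taylor-expansion workaround is shown below. The helper names (taylorCos, taylorSin, taylorExp) and the truncation orders are illustrative assumptions, not part of Slang or of the original code. The series are truncated Maclaurin expansions, so they are only accurate for inputs near zero.

```slang
// Hypothetical helpers: truncated Maclaurin series built from plain
// arithmetic, avoiding the generic stdlib cos/sin/exp.

[Differentiable]
float taylorCos(float x)
{
    float x2 = x * x;
    // 1 - x^2/2! + x^4/4! - x^6/6!
    return 1.0 - x2 / 2.0 + x2 * x2 / 24.0 - x2 * x2 * x2 / 720.0;
}

[Differentiable]
float taylorSin(float x)
{
    float x2 = x * x;
    // x - x^3/3! + x^5/5! - x^7/7!
    return x * (1.0 - x2 / 6.0 + x2 * x2 / 120.0 - x2 * x2 * x2 / 5040.0);
}

[Differentiable]
float taylorExp(float x)
{
    // 1 + x + x^2/2! + x^3/3! + x^4/4! + x^5/5!, in Horner form
    return 1.0 + x * (1.0 + x * (0.5 + x * (1.0 / 6.0 + x * (1.0 / 24.0 + x / 120.0))));
}

[Differentiable]
float3 test(float3 val)
{
    return float3(taylorCos(val.x), taylorSin(val.y), taylorExp(val.z));
}
```

With inputs such as float3(tid), which grow well beyond the neighbourhood of zero, these approximations diverge quickly from the true functions, so range reduction or more terms would be needed for meaningful results.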