shader-slang / slang

Making it easier to work with shaders
http://shader-slang.com
MIT License
Higher order auto-diff issues #3775

Closed aroidzap closed 6 months ago

aroidzap commented 8 months ago

1) float values marked no_diff are not treated as no_diff in higher-order derivatives: a no_diff float parameter of the forward function stays a plain float in a bwd_diff(func)(...) call, but becomes a DifferentialPair<float> in a bwd_diff(bwd_diff(func))(...) call. uint values work as expected.

2) Using the built-in cos(), sin() and exp() functions in higher-order differentiation crashes the Slang compiler with exit code 3221225725 (0xC00000FD, a stack overflow caused by runaway recursion). As a workaround, Taylor expansions of these functions can be used.
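
A minimal sketch of the Taylor-expansion workaround mentioned in 2), for illustration only. The helper names taylorCos, taylorSin and taylorExp are hypothetical, and the number of terms is an arbitrary assumption. Because the helpers are plain polynomials, they avoid the generic stdlib intrinsics, but truncated series like these are only accurate for inputs near zero.

```slang
// Hypothetical helpers (not part of the Slang stdlib): truncated Maclaurin
// series written in Horner form, so no built-in cos/sin/exp is called.

[Differentiable]
float taylorCos(float x)
{
    // 1 - x^2/2! + x^4/4! - x^6/6!
    float x2 = x * x;
    return 1.0 - x2 / 2.0 * (1.0 - x2 / 12.0 * (1.0 - x2 / 30.0));
}

[Differentiable]
float taylorSin(float x)
{
    // x - x^3/3! + x^5/5! - x^7/7!
    float x2 = x * x;
    return x * (1.0 - x2 / 6.0 * (1.0 - x2 / 20.0 * (1.0 - x2 / 42.0)));
}

[Differentiable]
float taylorExp(float x)
{
    // 1 + x + x^2/2! + x^3/3! + x^4/4!
    return 1.0 + x * (1.0 + x / 2.0 * (1.0 + x / 3.0 * (1.0 + x / 4.0)));
}

[Differentiable]
float3 test(float3 val)
{
    return float3(taylorCos(val.x), taylorSin(val.y), taylorExp(val.z));
}
```

Add more terms, or reduce the argument range first, if the inputs are not small.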

aroidzap commented 7 months ago

Regarding the recursion error, recursion happens in: https://github.com/shader-slang/slang/blob/56928794d0800824dc91e150cb345b5fec24d930/source/slang/slang-ir-autodiff-transcriber-base.cpp#L152

Slang source to reproduce the issue:

[Differentiable]
float3 test(float3 val)
{
    return float3(cos(val.x), sin(val.y), exp(val.z));
}

[shader("compute")]
[numthreads(16, 16, 1)]
void test_main(uint3 tid: SV_DispatchThreadID)
{
    float3 a = float3(tid);
    float3 b = float3(0);
    DifferentialPair<DifferentialPair<float3>> dp_dp_a = diffPair(diffPair(a));
    DifferentialPair<float3> dp_b = diffPair(b);
    bwd_diff(bwd_diff(test))(dp_dp_a, dp_b);
}

saipraveenb25 commented 7 months ago

Thank you for providing a minimal reproducer! This will make it much simpler to figure out what's wrong. Will provide more updates soon.

saipraveenb25 commented 7 months ago

I've done some digging into this issue. It appears that there are several unhandled cases when generics and multiple reverse-mode passes interact.

The specific issue is that reverse-mode auto-diff produces function-dependent types (the intermediate context data structure depends on the function contents). These types are not handled correctly when they appear within generic containers (all stdlib methods like cos & sin are generics). This doesn't matter for a single reverse-mode pass or for reverse-mode over forward-mode. However, it becomes an issue with multiple layered bwd_diff calls.

Implementing the fixes is slightly involved, so I'll continue to work on it and provide updates (I'm moving it to Q2 2024).

In the meantime, you can work around this by using fwd_diff for the inner call instead of bwd_diff (i.e. bwd_diff(fwd_diff(fn))), which also computes Hessians more efficiently. Note that multiple reverse-mode passes (as in your snippet) lead to extremely inefficient code and are usually not the recommended approach for computing Hessians.
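
As a rough sketch of that workaround, the reproducer's entry point could be rewritten along these lines. This is an untested adaptation, not a confirmed fix: the direction v, the output weights, and the reading of the result fields in the comments are assumptions based on how DifferentialPair nests. Reverse mode over forward mode yields a Hessian-vector product: forward mode computes J(x) v, and the reverse pass then differentiates w^T J(x) v with respect to x.

```slang
[Differentiable]
float3 test(float3 val)
{
    return float3(cos(val.x), sin(val.y), exp(val.z));
}

[shader("compute")]
[numthreads(16, 16, 1)]
void test_main(uint3 tid: SV_DispatchThreadID)
{
    float3 a = float3(tid);
    float3 v = float3(1); // assumed direction for the forward-mode pass

    // Inner pair: primal input a with tangent v.
    // Outer pair: its differential is zero on input and receives gradients.
    DifferentialPair<DifferentialPair<float3>> dp_dp_a = diffPair(diffPair(a, v));

    // Cotangent for the (primal, tangent) output of fwd_diff(test):
    // zero weight on the primal output, weight w = 1 on the tangent output.
    DifferentialPair<float3> d_out = diffPair(float3(0), float3(1));

    bwd_diff(fwd_diff(test))(dp_dp_a, d_out);

    // Assumed layout after the call:
    //   dp_dp_a.d.p : gradient of w^T J(a) v with respect to a
    //   dp_dp_a.d.d : gradient with respect to v, i.e. J(a)^T w
}
```

Running this once per basis vector v would assemble the full Hessian column by column.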

csyonghe commented 6 months ago

@aroidzap Are you currently blocked by this issue? Does changing to bwd_diff(fwd_diff(f)) work for you?

aroidzap commented 6 months ago

Hello, thanks for asking. This issue does not block my work; I was just curious to try higher-order diff and ran into this bug. I haven't had time yet to check whether bwd_diff(fwd_diff(f)) works in my code.

bmillsNV commented 6 months ago

Looks like this issue can be closed. @aroidzap please reopen if required.