EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Question about functions that return the same argument they mutate, with different activities for the return and the argument #1411

Closed simsurace closed 1 month ago

simsurace commented 5 months ago

In #1307, the following question was raised. It is strange that the explicit autodiff call seems to be inconsistent with EnzymeTestUtils behavior here:

using TestEnv; TestEnv.activate()
using Enzyme, EnzymeTestUtils, LinearAlgebra, Test

Te = Float64
A = exp(Symmetric(rand(Te, 4, 4)))
B = rand(Te, 4)
C = cholesky(A)

f(x, y) = ldiv!(x, y)

res1 = let dB = ones(Te, 4), dC = make_zero(C)
    ret, dret = autodiff(Forward, f, Duplicated, Duplicated(C, dC), Duplicated(B, dB))
    dB, dret
end

res2 = let dB = ones(Te, 4)
    ret, dret = autodiff(Forward, f, Duplicated, Const(C), Duplicated(B, dB))
    dB, dret
end

res1 .≈ res2 # (true, true)

Enzyme.API.runtimeActivity!(true)
test_forward(f, Duplicated, (C, Const), (B, Duplicated)) # 1/11 tests fail
test_forward(f, Duplicated, (C, Duplicated), (B, Duplicated)) # 11/11 tests pass

@sethaxen commented that:

The problem seems to be that if you pass Const(C) but the other activities are Duplicated, then Enzyme turns the Const(C) into a Duplicated(C, C) (you can see this by adding a @show typeof(fact), fact.val === fact.dval in the rule and then calling autodiff(Forward, f, Duplicated, Const(C), Duplicated(B, dB))). This results in a completely wrong derivative being computed, since the shadow is the primal instead of zero. @wsmoses, this looks like an Enzyme bug, but I'm not certain how to raise it in an issue, since it can only be caught within a rule.

It is unclear if this activity pattern should be allowed at all, but something like the following works:

julia> function f!(x, y)
           y .= x .* y
           return y
       end
f! (generic function with 1 method)

julia> x = rand(4); y = rand(4); dy = ones(4);

julia> autodiff(Forward, f!, Const, Const(x), Duplicated(y, dy))
()

julia> dy ≈ x
true

wsmoses commented 5 months ago

cc @sethaxen

So runtime activity does come with a significant sharp edge, which I think is what is causing the test to fail.

Specifically from the docs on runtime activity:

Instead, Enzyme has a special mode known as "Runtime Activity" which can handle these types of situations. It can come with a minor performance reduction, and is therefore off by default. It can be enabled with Enzyme.API.runtimeActivity!(true) right after importing Enzyme for the first time.

The way Enzyme's runtime activity resolves this issue is to return the original primal variable as the derivative whenever it needs to denote the fact that a variable is a constant. As this issue can only arise with mutable variables, they must be represented in memory via a pointer. All additional loads and stores will now be modified to first check if the primal pointer is the same as the shadow pointer, and if so, treat it as a constant. Note that this check is not saying that the same arrays contain the same values, but rather the same backing memory represents both the primal and the shadow (e.g. a === b or equivalently pointer(a) == pointer(b)).

Enabling runtime activity does, therefore, come with a sharp edge, which is that if the computed derivative of a function is mutable, one must also check to see if the primal and shadow represent the same pointer, and if so the true derivative of the function is actually zero.
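A minimal sketch of the check described above, written in plain Julia without calling Enzyme. The helper name `effective_shadow` is hypothetical and is not part of Enzyme or EnzymeTestUtils; it only illustrates that a shadow aliasing the primal should be read as a zero derivative, while a shadow that merely holds equal values should not.

```julia
# Hypothetical helper (not an Enzyme API): interpret a returned shadow under
# runtime activity. If the shadow is the very same object as the primal,
# the result was constant, so the true derivative is zero.
function effective_shadow(primal::Array, shadow::Array)
    # `===` on mutable arrays compares identity (same backing object),
    # not element values.
    return primal === shadow ? zero(primal) : shadow
end

y = [1.0, 2.0, 3.0]
dy_aliased = y        # shadow aliases the primal: marks a constant
dy_active = copy(y)   # equal values, different memory: a real derivative

effective_shadow(y, dy_aliased)  # an array of zeros
effective_shadow(y, dy_active)   # returned unchanged
```

Whether a test utility should apply such a normalization automatically is exactly the open question in this thread; the sketch only shows what the comparison looks like.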

Do we need to extend EnzymeTestUtils to check whether runtime activity is on and, if so, treat a returned result that shares its pointer with the primal as equivalent to 0?

In the case of the rule, I think it is likely that the rule you're working on doesn't support runtime activity at the moment (most Julia rules don't, but all Enzyme internal ones and the ones in Enzyme.jl generally do). Specifically, if pointer(a) == pointer(da), it needs to be considered constant if runtime activity is on.

We do try to explicitly convert runtime activity into static activity for the args of rules, though, to avoid giving users this complexity (https://github.com/EnzymeAD/Enzyme.jl/blob/1e27530c10989926c45377e1efd47f047415603e/src/rules/jitrules.jl#L365), but if it's an array in a struct, that may need user-level checks.
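As an illustration of what such a user-level check might look like for an array stored in a struct, here is a sketch using a Cholesky factorization, as in the original example. The function name `shadow_is_constant` is hypothetical and not part of Enzyme; the sketch assumes that comparing the identity of the stored `factors` arrays is the relevant test, which mirrors the `fact.val === fact.dval` observation quoted earlier.

```julia
using LinearAlgebra

# Hypothetical helper (not an Enzyme API): decide whether a Cholesky shadow
# should be treated as constant under runtime activity by comparing the
# identity of the arrays stored inside the two structs.
function shadow_is_constant(C::Cholesky, dC::Cholesky)
    # getfield avoids the getproperty overloads (:U, :L) on Cholesky.
    return getfield(C, :factors) === getfield(dC, :factors)
end

A = [4.0 1.0; 1.0 3.0]     # small symmetric positive definite matrix
C = cholesky(A)

dC_aliased = C             # shadow is the primal itself: constant
dC_active = cholesky(A)    # same values, separately allocated factors

shadow_is_constant(C, dC_aliased)  # true
shadow_is_constant(C, dC_active)   # false
```

A rule that supports runtime activity would presumably branch on a check of this kind and skip the derivative contribution of the constant argument.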

wsmoses commented 5 months ago

A constant return is generally permitted; the issue here is when a constant input is turned into a duplicated return.

wsmoses commented 5 months ago

Oh, perhaps I'm not understanding this issue properly; I will try to digest it at length shortly.

wsmoses commented 1 month ago

I think this is answered above; if it is still unclear, please reopen.