EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Dropped gradient only in reverse mode with cached values #1492

Closed: ChrisRackauckas closed this issue 4 weeks ago

ChrisRackauckas commented 1 month ago
using Enzyme

not_decapode_f = begin
    function simulate()
        begin
            var"__•1" = Vector{Float64}(undef, 1)
            __V̇ = Vector{Float64}(undef, 1)
        end
        f(du, u, p, t) = begin
                begin
                    X = u[1]
                    k = p[1]
                    V = u[2]
                end
                var"•1" = var"__•1"
                V̇ = __V̇
                var"•1" .= (.-)(k)
                V̇ .= var"•1" .* X
                du[1] = V
                du[2] = V̇[1]
                nothing
            end
    end
end

du = [0.0,0.0]
u = [1.0,3.0]
p = [1.0]
f = not_decapode_f()
f(du,u,p,1.0)
df = Enzyme.make_zero(f)
d_du = Enzyme.make_zero(du)
d_u = Enzyme.make_zero(u)
dp = Enzyme.make_zero(p)

d_du .= 0; d_u .= 0; dp[1] = 1.0
Enzyme.autodiff(Enzyme.Forward, Duplicated(f, df), Enzyme.Duplicated(du, d_du),
                Enzyme.Duplicated(u,d_u), Enzyme.Duplicated(p,dp),
                Enzyme.Const(1.0))

d_du # [0.0,-1.0]
dp # [1.0]
du # [3.0,-1.0]

d_du .= 0; d_u .= 0; dp[1] = 1.0
Enzyme.autodiff(Enzyme.Reverse, Duplicated(f, df), Enzyme.Duplicated(du, d_du),
                Enzyme.Duplicated(u,d_u), Enzyme.Duplicated(p,dp),
                Enzyme.Const(1.0))

d_du # [0.0,0.0]
dp # [3.0]
du # [3.0,-1.0]

# Confirm with finite difference

ppet = [1 + 1e-7]
du2 = copy(du)
f(du,u,p,1.0)
f(du2,u,ppet,1.0)
(du2 - du) ./ 1e-7 # [0.0,-1.0000000005838672]

MWE of https://github.com/DARPA-ASKEM/sciml-service/issues/177

vchuravy commented 1 month ago

So dp being wrong is due to df holding on to values from the earlier Forward call: the shadow temporaries inside df are not reset between calls.

julia> Enzyme.autodiff(Enzyme.Forward, Duplicated(f, df), Enzyme.Duplicated(du, d_du),
                       Enzyme.Duplicated(u,d_u), Enzyme.Duplicated(p,dp),
                       Enzyme.Const(1.0))
()

julia> df.__V̇
1-element Vector{Float64}:
 -1.0

julia> df.var"__•1"
1-element Vector{Float64}:
 -1.0

If you reset df with df = Enzyme.make_zero(f), or manually zero out the temporaries in it:

julia> d_du .= 0; d_u .= 0; dp[1] = 1.0;

julia> df = Enzyme.make_zero(f)
f (generic function with 1 method)

julia> Enzyme.autodiff(Enzyme.Reverse, Duplicated(f, df), Enzyme.Duplicated(du, d_du),
                       Enzyme.Duplicated(u,d_u), Enzyme.Duplicated(p,dp),
                       Enzyme.Const(1.0))
((nothing, nothing, nothing, nothing),)

julia> d_du # [0.0,0.0]
2-element Vector{Float64}:
 0.0
 0.0

julia> dp # [3.0]
1-element Vector{Float64}:
 1.0

julia> du # [3.0,-1.0]
2-element Vector{Float64}:
  3.0
 -1.0

Then dp is correct, but d_du still isn't.
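The two dp values above (3.0 with the stale df, 1.0 after make_zero) and the all-zero d_du can be reproduced with a hand-written adjoint of f. This is a sketch for this one function, not Enzyme's generated code: manual_reverse! is a hypothetical name, db1 and db2 stand in for the shadows of var"__•1" and __V̇ held in df, and it assumes Enzyme's convention of accumulating adjoints with += and resetting the shadow of any location the primal overwrote (including du) once that shadow has been consumed.

```julia
# Hypothetical hand-written reverse pass for the MWE's f (a sketch, not Enzyme output).
# db1 and db2 play the role of the shadows of var"__•1" and __V̇ stored in df.
# Assumed convention: adjoints accumulate with +=, and the shadow of a location
# the primal overwrote is read and then reset to zero.
function manual_reverse!(d_du, d_u, dp, db1, db2, u, p)
    X = u[1]
    k = p[1]
    b1 = -k                      # primal value of var"•1"[1]

    # du[2] = V̇[1]: the seed on du[2] moves into the shadow of V̇
    db2[1] += d_du[2]
    d_du[2] = 0.0
    # du[1] = V, with V = u[2]
    d_u[2] += d_du[1]
    d_du[1] = 0.0
    # V̇ .= var"•1" .* X
    db1[1] += db2[1] * X
    d_u[1] += db2[1] * b1
    db2[1] = 0.0
    # var"•1" .= -k, with k = p[1]
    dp[1] -= db1[1]
    db1[1] = 0.0
    return nothing
end

u = [1.0, 3.0]
p = [1.0]

# As in the first post: dp pre-set to 1.0, d_du zero, and df's buffers still
# holding the -1.0 tangents left behind by the Forward call.
d_du = [0.0, 0.0]; d_u = [0.0, 0.0]; dp = [1.0]
manual_reverse!(d_du, d_u, dp, [-1.0], [-1.0], u, p)
@show dp d_du    # dp = [3.0], d_du = [0.0, 0.0]

# Same call with zeroed buffers (what df = Enzyme.make_zero(f) provides).
d_du = [0.0, 0.0]; d_u = [0.0, 0.0]; dp = [1.0]
manual_reverse!(d_du, d_u, dp, [0.0], [0.0], u, p)
@show dp d_du    # dp = [1.0], d_du = [0.0, 0.0]
```

In the stale case the leftover -1.0 values act as extra seeds on the temporaries: db2 pushes -1.0 into db1, which becomes -2.0, and the k = p[1] step subtracts that from dp, giving 1.0 + 2.0 = 3.0.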

ChrisRackauckas commented 1 month ago

d_du is the one I have to actually use, so it's a bit worrisome 😅

vchuravy commented 1 month ago

Yeah, just letting you know that in any case you will have to zero the temporaries in df.
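If you prefer to reuse an existing shadow rather than allocate a new one with Enzyme.make_zero, the temporaries can be reset in place by walking the closure's captured fields, the same fields vchuravy read with df.__V̇. This is a minimal sketch under stated assumptions: zero_captured_arrays! is a hypothetical helper, not an Enzyme API; it only handles captured variables that are plain floating-point arrays stored directly in the closure (Core.Box-wrapped variables and nested structs are skipped); and make_buffered_f is a cut-down stand-in for not_decapode_f.

```julia
# Hypothetical helper (not part of Enzyme): zero every floating-point array
# captured directly as a field of a closure, e.g. a shadow closure `df`.
function zero_captured_arrays!(g)
    for name in fieldnames(typeof(g))
        x = getfield(g, name)
        if x isa AbstractArray{<:AbstractFloat}
            fill!(x, zero(eltype(x)))
        end
        # Other fields (Core.Box, scalars, nested structs) are left untouched.
    end
    return g
end

# Cut-down stand-in for not_decapode_f: one preallocated buffer captured by f.
function make_buffered_f()
    buf = Vector{Float64}(undef, 1)
    f(du, u, p, t) = begin
        buf .= -p[1] .* u[1]
        du[1] = u[2]
        du[2] = buf[1]
        nothing
    end
    return f
end

g = make_buffered_f()
du = zeros(2)
g(du, [1.0, 3.0], [1.0], 1.0)
@show g.buf                  # [-1.0]: the buffer keeps the last computed value
zero_captured_arrays!(g)
@show g.buf                  # [0.0]
```

With Enzyme you would call zero_captured_arrays!(df) before each autodiff call instead of re-creating df with Enzyme.make_zero(f); any captured state that is not a plain float array still needs its own handling.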

ChrisRackauckas commented 1 month ago

Does Enzyme have a utility to fill with zeros?

wsmoses commented 4 weeks ago

I'll see if I can look at this tomorrow when flying back to the US.

wsmoses commented 4 weeks ago

@ChrisRackauckas this is a calling-convention issue on your end.

In reverse mode you need to set the shadow of the output (here d_du, since f writes its result into du) to 1, not the shadow of the input.
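Concretely, the MWE's f computes du[1] = u[2] and du[2] = -p[1] * u[1], so its Jacobians can be written out by hand. Forward mode pushes an input tangent to the output (a Jacobian-vector product, J * v), while reverse mode pulls an output seed back to the inputs (a vector-Jacobian product, J' * w). The sketch below uses plain matrix algebra, not Enzyme, and the helper names jac_u and jac_p are hypothetical; it only shows which seed produces which result in this thread.

```julia
using LinearAlgebra

# Hand-derived Jacobians of the MWE's f (hypothetical helper names):
#   du[1] = u[2]
#   du[2] = -p[1] * u[1]
jac_u(u, p) = [0.0 1.0; -p[1] 0.0]          # ∂du/∂u, 2×2
jac_p(u, p) = reshape([0.0, -u[1]], 2, 1)   # ∂du/∂p, 2×1

u = [1.0, 3.0]
p = [1.0]

# Forward mode: seed an input tangent (dp), read the output tangent (d_du).
dp_seed = [1.0]
d_du_fwd = jac_p(u, p) * dp_seed            # [0.0, -1.0]

# Reverse mode: seed the output shadow (d_du), read the input shadows (dp, d_u).
d_du_seed = [0.0, 1.0]
dp_rev = jac_p(u, p)' * d_du_seed           # [-1.0], i.e. ∂du[2]/∂p[1]
d_u_rev = jac_u(u, p)' * d_du_seed          # [-1.0, 0.0], i.e. ∂du[2]/∂u
```

Setting dp = [1.0] in reverse mode with d_du zero, as in the first post, amounts to no output seed at all, so the correct reverse result simply leaves dp at its pre-set 1.0.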

using Enzyme

not_decapode_f = begin
    function simulate()
        begin
            var"__•1" = Vector{Float64}(undef, 1)
            __V̇ = Vector{Float64}(undef, 1)
        end
        f(du, u, p, t) = begin
                begin
                    X = u[1]
                    k = p[1]
                    V = u[2]
                end
                var"•1" = var"__•1"
                V̇ = __V̇
                var"•1" .= (.-)(k)
                V̇ .= var"•1" .* X
                du[1] = V
                du[2] = V̇[1]
                nothing
            end
    end
end

du = [0.0,0.0]
u = [1.0,3.0]
p = [1.0]
f = not_decapode_f()
f(du,u,p,1.0)
df = Enzyme.make_zero(f)
d_du = Enzyme.make_zero(du)
d_u = Enzyme.make_zero(u)
dp = Enzyme.make_zero(p)

df = Enzyme.make_zero(f)
d_du .= 0; d_u .= 0; dp[1] = 1.0
Enzyme.autodiff(Enzyme.Reverse, Duplicated(f, df), Enzyme.Duplicated(du, d_du),
                Enzyme.Duplicated(u,d_u), Enzyme.Duplicated(p,dp),
                Enzyme.Const(1.0))

@show d_du # [0.0,0.0]
@show dp # [1.0]
@show du # [3.0,-1.0]

df = Enzyme.make_zero(f)
# compute the gradient of du[2] (seed the output shadow: d_du[2] = 1)
d_du = [0.0, 1.0]; d_u .= 0; dp[1] = 0.0
Enzyme.autodiff(Enzyme.Reverse, Duplicated(f, df), Enzyme.Duplicated(du, d_du),
                Enzyme.Duplicated(u,d_u), Enzyme.Duplicated(p,dp),
                Enzyme.Const(1.0))

# derivative of du[2] w.r.t. p, which is what the finite difference below computes (second entry).
@show dp # dp = [-1.0]

# Confirm with finite difference

ppet = [1 + 1e-7]
du2 = copy(du)
f(du,u,p,1.0)
f(du2,u,ppet,1.0)
@show (du2 - du) ./ 1e-7 # [0.0,-1.0000000005838672]
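A standard way to check forward and reverse results against each other is the dot-product test: for any input direction v and output seed w, w' * (J * v) must equal (J' * w)' * v. The sketch below applies it with a forward finite difference standing in for J * v and the reverse-mode value dp = [-1.0] from the script above standing in for J' * w. F is an assumed out-of-place restatement of f, fd_jvp_p is a hypothetical helper, and the tolerance is a loose choice to absorb the finite-difference error.

```julia
using LinearAlgebra

# Assumed out-of-place restatement of the MWE's f.
F(u, p) = [u[2], -p[1] * u[1]]

# Forward-difference approximation of (∂F/∂p) * v (hypothetical helper).
fd_jvp_p(u, p, v; h = 1e-7) = (F(u, p .+ h .* v) .- F(u, p)) ./ h

u = [1.0, 3.0]
p = [1.0]

v = [1.0]          # direction in parameter space
w = [0.0, 1.0]     # output seed, i.e. d_du in the reverse-mode call
vjp_p = [-1.0]     # dp returned by reverse mode for that seed

lhs = dot(w, fd_jvp_p(u, p, v))   # w' * (J v), approximately -1.0
rhs = dot(vjp_p, v)               # (J' w)' * v
@show lhs rhs
```

If the two sides disagree by more than the finite-difference error, either the seeds were placed on the wrong shadows or stale shadow state leaked into the reverse call.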

ChrisRackauckas commented 4 weeks ago

oh duh 🤦