EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Example of JVP / J'VP with Krylov.jl #1957

Open vchuravy opened 2 weeks ago

vchuravy commented 2 weeks ago

cc: @michel2323 & @amontoison

With @lcandiot I was wondering how to write a proper JVP / JᵀVP with Enzyme, and we finally converged on the version below.

This might turn into a nice example one of these days. @wsmoses, any ideas on how to avoid the calls to zero(y) and zero(w)/copy(w)?

using Krylov, Enzyme, LinearOperators, ForwardDiff, LinearAlgebra

xk = ones(2)

F(x) = [x[1]^4 - 3; exp(x[2]) - 2; log(x[1]) - x[2]^2]

"""
Calculate the Jacobian-Vector Product in-place by updating `y`.
"""
function JVP!(y, f::F, u, v) where F
    Enzyme.autodiff(Forward, 
        (temp, v) -> (temp .= f(v); nothing),
        Const, 
        DuplicatedNoNeed(zero(y), y),
        DuplicatedNoNeed(u, v))
    return nothing
end

"""
Calculate the Jacobian-Transpose Vector Product in-place by updating `y`.
"""
function JᵀVP!(y, f::F, u, w) where F
    y .= 0 # Enzyme expects y to be zero
    Enzyme.autodiff(Enzyme.Reverse, 
        (out, x) -> (out .= f(x); nothing),
        Const, 
        DuplicatedNoNeed(zero(w), copy(w)), # copy w, since Enzyme would otherwise zero it in place
        DuplicatedNoNeed(u, y))
    return nothing
end

J(y, v) = ForwardDiff.derivative!(y, t -> F(xk + t * v), 0)
Jᵀ(y, u) = ForwardDiff.gradient!(y, x -> dot(F(x), u), xk)

w = rand(3)
v = rand(2)

y_fwd = zeros(2)
Jᵀ(y_fwd, w)
@show y_fwd

y_enz = zeros(2)
@show JᵀVP!(y_enz, F, xk, w)
@show y_enz

@assert y_enz ≈ y_fwd

y2_fwd = zeros(3)
J(y2_fwd, v)
@show y2_fwd

y2_enz = zeros(3)
@show JVP!(y2_enz, F, xk, v)
@show y2_enz

@assert y2_enz ≈ y2_fwd

opJ_FWD = LinearOperator(Float64, 3, 2, false, false, (y, v) -> J(y, v),
                                                      (y, w) -> Jᵀ(y, w),
                                                      (y, u) -> Jᵀ(y, u))

x_forward, _ = lsmr(opJ_FWD, -F(xk))

opJ = LinearOperator(Float64, 3, 2, false, false, (y, v) -> JVP!(y, F, xk, v),
                                                  (y, w) -> JᵀVP!(y, F, xk, w),
                                                  (y, u) -> JᵀVP!(y, F, xk, u))

x_enzyme, _ = lsmr(opJ, -F(xk))
x_enzyme ≈ x_forward
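
A cheap extra check for operators like these is the adjoint (dot-product) test: for any v and w, dot(w, J*v) should equal dot(Jᵀ*w, v). Below is a minimal sketch that runs without Enzyme or ForwardDiff by using a hand-derived Jacobian of F. The helpers `Jac`, `jvp_ref!` and `jtvp_ref!` are hypothetical names introduced only for this illustration; the same test could be pointed at JVP! and JᵀVP! instead.

```julia
using LinearAlgebra

F(x) = [x[1]^4 - 3; exp(x[2]) - 2; log(x[1]) - x[2]^2]

# Hand-derived 3×2 Jacobian of F (hypothetical helper, not part of Enzyme).
Jac(x) = [4x[1]^3  0.0;
          0.0      exp(x[2]);
          1/x[1]   -2x[2]]

# Reference in-place products; both return the updated output vector.
jvp_ref!(y, x, v)  = mul!(y, Jac(x), v)
jtvp_ref!(y, x, w) = mul!(y, transpose(Jac(x)), w)

xk = ones(2)
v, w = rand(2), rand(3)

Jv  = jvp_ref!(zeros(3), xk, v)   # J * v, length 3
Jtw = jtvp_ref!(zeros(2), xk, w)  # Jᵀ * w, length 2

# Adjoint test: ⟨w, J v⟩ and ⟨Jᵀ w, v⟩ must agree up to rounding.
@assert dot(w, Jv) ≈ dot(Jtw, v)
```

If the two inner products disagree, the forward and transpose products are not adjoints of each other, which is exactly the kind of mismatch that makes lsmr misbehave.
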

wsmoses commented 2 weeks ago

Depending on the array type, I think doing copyto! would let you do:


function JVP!(y, f::F, u, v) where F 
    Enzyme.autodiff(Forward, 
        (temp, v) -> (temp .= f(v); nothing),
        Const, 
        DuplicatedNoNeed(similar(y), y), # any undef thing; similar(y) is one uninitialized choice
        DuplicatedNoNeed(u, v))
    return nothing
end

wsmoses commented 2 weeks ago

I think the same applies to zero(w); the copy(w), however, is harder.

vchuravy commented 2 weeks ago

For future reference:

# https://www.aanda.org/articles/aa/full_html/2016/02/aa27339-15/aa27339-15.html
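# Forward finite-difference approximation of the JVP J(u) * v (needs LinearAlgebra for norm)
function JVP_Finite_Diff(F, u, v)
    λ = 10e-6 # note: 10e-6 == 1e-5
    δ = λ * (λ + norm(u, Inf) / norm(v, Inf)) # step size scaled by ‖u‖∞ / ‖v‖∞

    (F(u + δ .* v) - F(u)) ./ δ
end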
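
As a rough usage sketch, the finite-difference JVP can be compared against an exact product for the F from the first comment. The function body is copied from above; the hand-derived Jacobian `Jac`, the direction `v`, and the tolerance are assumptions chosen only for illustration.

```julia
using LinearAlgebra

F(x) = [x[1]^4 - 3; exp(x[2]) - 2; log(x[1]) - x[2]^2]

# Copied from the snippet above.
function JVP_Finite_Diff(F, u, v)
    λ = 10e-6
    δ = λ * (λ + norm(u, Inf) / norm(v, Inf))

    (F(u + δ .* v) - F(u)) ./ δ
end

# Hand-derived 3×2 Jacobian of F (hypothetical helper for the comparison).
Jac(x) = [4x[1]^3  0.0;
          0.0      exp(x[2]);
          1/x[1]   -2x[2]]

u = ones(2)
v = [0.3, -0.7]   # arbitrary nonzero direction

jv_fd    = JVP_Finite_Diff(F, u, v)
jv_exact = Jac(u) * v

# A one-sided difference is only first-order accurate, so use a loose tolerance.
@assert isapprox(jv_fd, jv_exact; rtol = 1e-3)
```

Note that the step size divides by norm(v, Inf), so this formula is not usable for v == 0.
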

vchuravy commented 2 weeks ago

By "any undef thing", do you mean that a "Vector{Float64}(undef, 0)" would work?

wsmoses commented 2 weeks ago

I think so; in C we get away with passing a literal nullptr in these kinds of cases.

amontoison commented 2 weeks ago

Do you plan to add jvp and jtvp to the Enzyme.jl API? Just so I know whether I should wait before adding an example to the Krylov.jl documentation.