EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Syntactic sugar for `vjp` #1853

Status: Open. gdalle opened this issue 4 days ago.

gdalle commented 4 days ago

As discussed on Slack, it is sometimes annoying to use reverse mode because you need `autodiff_thunk` whenever

A solution suggested by @wsmoses would be syntactic sugar for `vjp` (and maybe `jvp`, though that is less necessary because `autodiff` gets the job done every time). There are a few design questions around this:
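For reference, in standard notation: for $f: \mathbb{R}^n \to \mathbb{R}^m$ with Jacobian $J_f(x) \in \mathbb{R}^{m \times n}$, a JVP pushes a tangent forward and a VJP pulls a cotangent back. The second line works this out for the scalar test function used in the tests further down, with seed $\bar{y} = -5$ at $(x, y) = (3, 4)$.

```math
\operatorname{jvp}(f, x, \dot{x}) = J_f(x)\,\dot{x} \in \mathbb{R}^m,
\qquad
\operatorname{vjp}(f, x, \bar{y}) = J_f(x)^\top \bar{y} \in \mathbb{R}^n
\\[6pt]
f(x, y) = x^2 + 10y,\quad
J_f(x, y) = \begin{bmatrix} 2x & 10 \end{bmatrix},\quad
J_f(3, 4)^\top (-5) = \begin{bmatrix} -30 \\ -50 \end{bmatrix}
```

Reverse mode computes the VJP. With an `Active` scalar return, `autodiff` uses the seed $\bar{y} = 1$ and so returns the plain gradient; a different seed, or an array output annotated `Duplicated`, is the case the MWE below handles through split mode and `autodiff_thunk`.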

gdalle commented 4 days ago

Here's an MWE:

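```julia
using Enzyme
using Enzyme.EnzymeCore: ReverseModeSplit

# Vector-Jacobian product: pull the output cotangent `dresult` back to the inputs.
# `ReturnPrimal` is the first type parameter of the split mode
# (e.g. `true` for ReverseSplitWithPrimal).
function vjp(
    rmode::ReverseModeSplit{ReturnPrimal},
    dresult,
    f::FA,
    ::Type{RA},
    args::Vararg{Annotation,N},
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
    # Build the forward (augmented primal) and reverse thunks for these annotation types
    forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
    # Forward pass: tape, primal result (if requested), and the result's shadow
    # (the shadow is only used when the return is Duplicated)
    tape, result, shadow_result = forward(f, args...)
    if RA <: Active
        # Active return: the seed is passed directly to the reverse pass
        dinputs = only(reverse(f, args..., dresult, tape))
    else
        # Duplicated return: seed the output by accumulating into its shadow
        shadow_result .+= dresult  # TODO: generalize beyond arrays
        dinputs = only(reverse(f, args..., tape))
    end
    if ReturnPrimal
        return (dinputs, result)
    else
        return (dinputs,)
    end
end
```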
Tests

```julia
f(x, y) = x^2 + 10y  # scalar output
g(x, y) = [f(x, y)]  # vector output

x, y = 3.0, 4.0
dz = -5.0

z0 = f(x, y)
dx0 = dz * 2x
dy0 = dz * 10
```

```julia
julia> ((dx, dy), z) = vjp(ReverseSplitWithPrimal, dz, Const(f), Active, Active(x), Active(y))
((-30.0, -50.0), 49.0)

julia> z == z0 && dx == dx0 && dy == dy0
true

julia> ((dx, dy), vz) = vjp(
           ReverseSplitWithPrimal, [dz], Const(g), Duplicated, Active(x), Active(y)
       )
((-30.0, -50.0), [49.0])

julia> vz == [z0] && dx == dx0 && dy == dy0
true
```
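One way to cross-check the hand-derived expected values (`dx0`, `dy0`) without relying on Enzyme at all is a central finite-difference approximation of the VJP. The helper `fd_vjp` below is hypothetical and not part of Enzyme; it is a minimal sketch assuming all inputs are real scalars. It approximates each Jacobian column with two function evaluations and contracts it with the seed, so it is only meant for small test cases like these.

```julia
# Hypothetical test helper (not part of Enzyme): approximates the VJP
# dresult' * J_f(args) with central finite differences.
# Assumes scalar real inputs; `dresult` must match the shape of f(args...).
function fd_vjp(f, dresult, args::Real...; h=1e-6)
    n = length(args)
    ntuple(n) do i
        # Perturb only the i-th argument by +h and -h
        plus = ntuple(j -> j == i ? args[j] + h : args[j], n)
        minus = ntuple(j -> j == i ? args[j] - h : args[j], n)
        # Column i of the Jacobian (a scalar or an array, like f's output)
        dcol = (f(plus...) .- f(minus...)) ./ (2h)
        # Contract the column with the cotangent seed
        sum(dresult .* dcol)
    end
end

f(x, y) = x^2 + 10y  # scalar output
g(x, y) = [f(x, y)]  # vector output

dx, dy = fd_vjp(f, -5.0, 3.0, 4.0)    # approximately (-30.0, -50.0)
gx, gy = fd_vjp(g, [-5.0], 3.0, 4.0)  # approximately (-30.0, -50.0)
```

Comparing with `isapprox` and a loose tolerance (rather than `==`) is the natural choice here, since finite differences carry rounding error even when, as for this quadratic, the truncation error vanishes.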