gdalle opened 2 months ago
Here's an MWE:
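```julia
using Enzyme
using Enzyme.EnzymeCore: ReverseModeSplit

# Vector-Jacobian product on top of split reverse mode:
# the forward pass records a tape, the output cotangent `dresult`
# is seeded, then the reverse pass propagates it back to the inputs.
function vjp(
    rmode::ReverseModeSplit{ReturnPrimal},
    dresult,
    f::FA,
    ::Type{RA},
    args::Vararg{Annotation,N},
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
    forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
    tape, result, shadow_result = forward(f, args...)
    if RA <: Active
        # Active return: the cotangent is handed to the reverse pass directly
        dinputs = only(reverse(f, args..., dresult, tape))
    else
        # Duplicated return: seed the output shadow in place
        shadow_result .+= dresult  # TODO: generalize beyond arrays
        dinputs = only(reverse(f, args..., tape))
    end
    if ReturnPrimal
        return (dinputs, result)
    else
        return (dinputs,)
    end
end
```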
As discussed on Slack, it is sometimes annoying to use reverse mode because you need `autodiff_thunk` whenever the output cotangent (the seed of the reverse pass) is something other than `1.0`.
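For comparison, here is a minimal sketch of the raw split-mode boilerplate that such sugar would hide, for an array-valued function seeded with an arbitrary cotangent. It assumes a recent Enzyme.jl that exports `ReverseSplitWithPrimal`; the function `square` and the cotangent `dy` are hypothetical and only for illustration.

```julia
using Enzyme

# Hypothetical example function: elementwise square, returns a new array.
square(x) = x .^ 2

x  = [1.0, 2.0, 3.0]
dx = zero(x)               # shadow of x; the reverse pass accumulates into it
dy = [1.0, 10.0, 100.0]    # output cotangent, not a single 1.0 seed

# Build the forward and reverse thunks (return annotated as Duplicated).
forward, reverse = autodiff_thunk(
    ReverseSplitWithPrimal, Const{typeof(square)}, Duplicated, Duplicated{typeof(x)}
)

# Forward pass: records the tape, returns the primal and the output shadow.
tape, y, y_shadow = forward(Const(square), Duplicated(x, dx))

# Seed the output shadow with the cotangent, then run the reverse pass.
y_shadow .+= dy
reverse(Const(square), Duplicated(x, dx), tape)

# dx now holds dy .* 2 .* x
```

Three calls plus manual seeding are needed just to get one VJP, which is the motivation for a `vjp` convenience wrapper.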
A solution suggested by @wsmoses would be syntactic sugar for `vjp` (and maybe `jvp`, but that is less necessary because `autodiff` gets the job done every time). There are a few design questions around this:

- Should `dreturn` be passed together with the return activity? For `Duplicated` returns, we would have to specify `dreturn = Duplicated(whatever, actual_thing_we_care_about)`.
- What about a batched `vjp`? How do we pass several cotangents if the return value is `Active`?