EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Support for duplicated kwargs for custom rules #1491

Open m-bossart opened 1 month ago

m-bossart commented 1 month ago

This issue tracks support for duplicated kwargs in custom rules, as mentioned in #1459.

The example I'm considering is the callback keyword argument when differentiating the solution of an ODE problem. The rule works properly when no callback is passed to solve:

using OrdinaryDiffEq
using SciMLSensitivity
using Enzyme

p = [3.0]
odef(du, u, p, t) = du .= u .* p
function f(p)
    ode_problem = ODEProblem{true}(odef, [2.0], (0.0, 1.0), p)
    sum(solve(ode_problem, Rodas5()))  
end
dp = make_zero(p)
Enzyme.autodiff(Reverse, f, Active, Duplicated(p, dp)) 

But it fails once the callback keyword argument is added, with the error "Enzyme: No custom augmented_primal rule was applicable":

condition(u, t, integrator) = t == 0.4
function affect!(integrator)
    integrator.p[1] = 0.5
end
cb = DiscreteCallback(condition, affect!)
function f(p, cb)
    ode_problem = ODEProblem{true}(odef, [2.0], (0.0, 1.0), p)
    sum(solve(ode_problem, Rodas5(); callback=cb, tstops=(0.4,)))  # add the callback
end
dp = make_zero(p)
dcb = make_zero(cb)
Enzyme.autodiff(Reverse, f, Active, Duplicated(p, dp), Duplicated(cb, dcb))  

The relevant custom rule is here: https://github.com/SciML/DiffEqBase.jl/blob/1c283e0250697047ee78327776799385c5939677/src/solve.jl#L1066C1-L1084C1

The dispatch that the rule is defined for is here: https://github.com/SciML/DiffEqBase.jl/blob/1c283e0250697047ee78327776799385c5939677/src/solve.jl#L1066
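
For context on why keyword arguments can need separate handling in rule matching: in Julia 1.9 and later, a keyword call is lowered to a call to Core.kwcall, with the keywords packed into a NamedTuple ahead of the function itself. The sketch below shows that equivalence with a hypothetical toy function, scaled_sum, which is not part of DiffEqBase or Enzyme. Whether this lowering is what keeps the rule from applying here is an assumption and is not confirmed anywhere in this issue.

```julia
# Hypothetical toy function for illustration only; not from DiffEqBase or Enzyme.
scaled_sum(x; scale = 1.0) = scale * sum(x)

x = [1.0, 2.0, 3.0]

# Ordinary keyword call.
direct = scaled_sum(x; scale = 2.0)

# The same call written the way Julia (1.9 or newer) lowers it:
# the keywords travel as a NamedTuple placed before the function.
lowered = Core.kwcall((; scale = 2.0), scaled_sum, x)

# Both forms dispatch to the same method body and return the same value.
@assert direct == lowered
```

In the failing example, callback=cb would sit inside that NamedTuple, so a Duplicated annotation would have to cover a value that is not among the positional arguments.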