This issue is for supporting duplicated kwargs for custom rules, as mentioned in #1459.
The example I'm considering is the `callback` keyword argument when differentiating the solution of an ODE problem. The rule works properly when no callback is passed to `solve`:

    using OrdinaryDiffEq
    using SciMLSensitivity
    using Enzyme

    p = [3.0]
    odef(du, u, p, t) = du .= u .* p   # in-place right-hand side: du = p * u

    function f(p)
        ode_problem = ODEProblem{true}(odef, [2.0], (0.0, 1.0), p)
        sum(solve(ode_problem, Rodas5()))
    end

    dp = make_zero(p)   # shadow that receives the gradient with respect to p
    Enzyme.autodiff(Reverse, f, Active, Duplicated(p, dp))
But it fails with `Enzyme: No custom augmented_primal rule was applicable` when the `callback` keyword argument is added:

    condition(u, t, integrator) = t == 0.4   # fires at the tstop t = 0.4
    function affect!(integrator)
        integrator.p[1] = 0.5
    end
    cb = DiscreteCallback(condition, affect!)

    function f(p, cb)
        ode_problem = ODEProblem{true}(odef, [2.0], (0.0, 1.0), p)
        sum(solve(ode_problem, Rodas5(); callback=cb, tstops=(0.4,)))   # add callback
    end

    dp = make_zero(p)
    dcb = make_zero(cb)   # shadow for the callback
    Enzyme.autodiff(Reverse, f, Active, Duplicated(p, dp), Duplicated(cb, dcb))
The relevant custom rule is here: https://github.com/SciML/DiffEqBase.jl/blob/1c283e0250697047ee78327776799385c5939677/src/solve.jl#L1066C1-L1084C1
The dispatch that the rule is defined for is here: https://github.com/SciML/DiffEqBase.jl/blob/1c283e0250697047ee78327776799385c5939677/src/solve.jl#L1066
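For background on why a keyword argument can change which dispatch a rule has to match, here is a minimal plain-Julia sketch of how keyword calls are lowered. It assumes Julia 1.9 or later, where a call with keyword arguments becomes a call to `Core.kwcall` with the keywords packed into a `NamedTuple`. The function `toy_solve` is a hypothetical stand-in, not the DiffEqBase method linked above, and the sketch does not establish how Enzyme matches the rule in this particular case.

```julia
# Hypothetical stand-in for a solver entry point (not the DiffEqBase method).
toy_solve(x; callback = nothing) = callback === nothing ? x : callback(x)

double_it(y) = 2y

# An ordinary keyword call ...
a = toy_solve(2.0; callback = double_it)

# ... is lowered (assumption: Julia 1.9 or later) to a call of Core.kwcall,
# with the keyword arguments passed as a NamedTuple in the first position,
# followed by the function itself and then the positional arguments.
b = Core.kwcall((; callback = double_it), toy_solve, 2.0)

# Without keywords the call goes straight to the positional method.
c = toy_solve(2.0)

@assert a == b
```

Seen this way, the callback reaches the called function inside an extra `NamedTuple` argument rather than as one of the positional arguments that the linked dispatch is written against, which may be why supporting a duplicated callback needs explicit handling in the rule.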