SciML / DiffEqBase.jl

The lightweight Base library for shared types and functionality for defining differential equation and scientific machine learning (SciML) problems

Enzyme Rule Fails for DDE #1061

Open m-bossart opened 2 weeks ago

m-bossart commented 2 weeks ago

The custom Enzyme rule for solve_up fails when doing sensitivity analysis on a delay differential equation problem (DDEProblem).

I expected the same rule that works for ODEProblem to also work for DDEProblem.

Minimal Reproducible Example 👇

using OrdinaryDiffEq
using DelayDiffEq
using SciMLSensitivity
using Enzyme
using Zygote
using Test

## Zygote matches Enzyme for ODEProblem 

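# In-place RHS: du/dt = u * p
# (this 4-argument f and the 1-argument f below are two methods of one generic function)
f(du, u, p, t) = du .= u .* p
u0p = [2.0, 3.0]  # u0p[1] is the initial condition, u0p[2] the parameter
function f(u0p)
    prob = ODEProblem{true}(f, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Rodas4(), abstol = 1e-9, reltol = 1e-9, saveat = 0.1))
end
f(u0p)
du0p_zygote = Zygote.gradient(f, u0p)[1]
du0p = zeros(2)  # shadow; Enzyme accumulates the gradient into it
Enzyme.autodiff(Reverse, f, Active, Duplicated(u0p, du0p))
@test du0p_zygote ≈ du0p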

## Enzyme fails for DDEProblem

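# In-place DDE RHS; the history function h is queried at t - 0.01
# (note: the lag used here, 0.01, differs from constant_lags = [0.1] below)
function f_delay(du, u, h, p, t)
    du .= u .* p .* h(p, t - 0.01)[1]
end
h(p, t) = ones(eltype(p), 2)  # constant history
u0p = [2.0, 3.0]
function f(u0p)  # redefines the 1-argument method of f from the ODE case
    prob = DDEProblem{true}(f_delay, u0p[1:1], h, (0.0, 0.2), u0p[2:2], constant_lags = [0.1])
    sum(solve(prob, MethodOfSteps(Rodas4()), abstol = 1e-9, reltol = 1e-9, saveat = 0.1))
end
f(u0p)
du0p_zygote = Zygote.gradient(f, u0p)[1]
du0p = zeros(2)  # fresh, zeroed shadow for the DDE gradient
Enzyme.autodiff(Reverse, f, Active, Duplicated(u0p, du0p)) # Fails
@test du0p_zygote ≈ du0p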
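The example allocates a fresh du0p = zeros(2) before each Enzyme call. A minimal sketch of why, using a hypothetical toy function g instead of the ODE/DDE solve: in reverse mode Enzyme adds the gradient into the Duplicated shadow rather than overwriting it, so reusing a shadow that is not zeroed would mix results from separate calls.

```julia
using Enzyme

# Hypothetical toy objective (not from the issue): g(x) = x[1]^2 + x[2]^2
g(x) = sum(abs2, x)

x = [1.0, 2.0]
dx = zeros(2)  # shadow must start at zero

Enzyme.autodiff(Reverse, g, Active, Duplicated(x, dx))
first_grad = copy(dx)  # gradient of g is 2x = [2.0, 4.0]

# A second call without resetting dx adds another 2x on top of the first result
Enzyme.autodiff(Reverse, g, Active, Duplicated(x, dx))
accumulated = copy(dx)
println(first_grad, " ", accumulated)  # [2.0, 4.0] [4.0, 8.0]

fill!(dx, 0.0)  # reset the shadow before reusing it
```

Allocating a new zeros(2) per call, as the example does, has the same effect as the fill! reset.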

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching MixedDuplicated(::ODESolution{…}, ::ODESolution{…})

Closest candidates are:
  MixedDuplicated(::T1, ::Base.RefValue{T1}) where T1
   @ EnzymeCore C:\Users\Matt Bossart\.julia\packages\EnzymeCore\a2poZ\src\EnzymeCore.jl:163
  MixedDuplicated(::T1, ::Base.RefValue{T1}, ::Bool) where T1
   @ EnzymeCore C:\Users\Matt Bossart\.julia\packages\EnzymeCore\a2poZ\src\EnzymeCore.jl:163

Stacktrace:
 [1] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(SciMLBase.wrap_sol), df::Nothing, primal_1::ODESolution{…}, shadow_1_1::ODESolution{…})
   @ Enzyme.Compiler C:\Users\Matt Bossart\.julia\packages\Enzyme\aioBJ\src\rules\jitrules.jl:147
Some type information was truncated. Use `show(err)` to see complete types.
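Both candidate signatures listed in the error require the shadow to be wrapped in a Base.RefValue of the same type as the primal, whereas the failing call (made from runtime_generic_augfwd for SciMLBase.wrap_sol) passed the ODESolution shadow directly. A minimal sketch of that mismatch, assuming only the two signatures shown above; the NamedTuple values are hypothetical stand-ins for the ODESolution objects, and this illustrates the error message rather than diagnosing the rule itself.

```julia
import EnzymeCore

# Hypothetical primal/shadow pair of the same concrete type
# (in the issue, both are ODESolution objects)
primal = (a = 1.0, b = [2.0])
shadow = (a = 0.0, b = [0.0])

# Matches the listed candidate MixedDuplicated(::T1, ::Base.RefValue{T1})
wrapped_ok = applicable(EnzymeCore.MixedDuplicated, primal, Ref(shadow))

# Mirrors the failing call: shadow passed directly, not wrapped in a Ref
unwrapped_ok = applicable(EnzymeCore.MixedDuplicated, primal, shadow)

println((wrapped_ok, unwrapped_ok))
```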

Environment:

Status `C:\Users\Matt Bossart\OneDrive - UCB-O365\Desktop\Transient Stability\TestingEnzyme\Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.4.2
  [d360d2e6] ChainRulesCore v1.24.0
  [b0b7db55] ComponentArrays v0.15.13
  [bcd4f6db] DelayDiffEq v5.47.3
  [2b5f629d] DiffEqBase v6.151.4
  [7da242da] Enzyme v0.12.17
  [26cc04aa] FiniteDifferences v0.12.32
  [f6369f11] ForwardDiff v0.10.36
  [929cbde3] LLVM v7.2.1
  [8913a72c] NonlinearSolve v3.13.0
  [1dea7af3] OrdinaryDiffEq v6.84.0
  [f0f68f2c] PlotlyJS v0.18.13
  [bed98974] PowerNetworkMatrices v0.10.3
  [398b2ede] PowerSimulationsDynamics v0.14.2
  [f00506e0] PowerSystemCaseBuilder v1.2.5
  [bcd98974] PowerSystems v3.3.0
  [295af30f] Revise v3.5.14
  [0bca4576] SciMLBase v2.41.3
  [1ed8b502] SciMLSensitivity v7.61.1
  [a759f4b9] TimerOutputs v0.5.24
  [e88e6eb3] Zygote v0.6.70
Julia Version 1.10.4
Commit 48d4fd4843 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 16 × 11th Gen Intel(R) Core(TM) i7-11800H @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, tigerlake)
Threads: 1 default, 0 interactive, 1 GC (on 16 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =
m-bossart commented 2 weeks ago

@wsmoses, any thoughts on this error as it relates to the custom rule here? https://github.com/SciML/DiffEqBase.jl/blob/master/ext/DiffEqBaseEnzymeExt.jl