Describe the bug 🐞
I'm trying to implement a discrete callback that changes the value of a parameter. I'm using ComponentArrays for the parameters because my application involves large systems with complicated parameter handling. Indexing into the ComponentArray by symbol inside the callback fails during sensitivity analysis when using the `ReverseDiffVJP()` option.
Expected behavior
I expect to be able to index into and modify the ComponentArray parameters using symbol indexing.
Minimal Reproducible Example 👇
The MWE shows a working version derived from an example in the tests. In the second version, the callback modifies `p` using the symbol `:a` instead of the numerical index `1`, which causes a failure during sensitivity analysis.
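One plausible mechanism, which is an unverified assumption and not something confirmed by the stack trace: if `integrator.p` reaches the callback as a plain `Vector` during the adjoint pass instead of a `ComponentArray`, then symbol indexing has no meaning there, while integer indexing still works. The sketch below uses only Base Julia to show that difference; `p_plain` and `symbol_index_error` are hypothetical names for illustration.

```julia
# Assumption (unverified): during the adjoint pass, `integrator.p` might be a
# plain Vector rather than a ComponentArray. This mimics that situation.
p_plain = [1.5, 1.0, 3.0, 1.0]

# Integer indexing works on any AbstractVector, as in the working MWE.
p_plain[1] = 2 * p_plain[1] - 0.5

# Symbol indexing on a plain Vector throws an ArgumentError.
symbol_index_error = try
    p_plain[:a] = 0.0
    nothing
catch err
    err
end

println(p_plain[1])                  # 2.5
println(typeof(symbol_index_error))  # ArgumentError
```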
```julia
using OrdinaryDiffEq, Zygote
using SciMLSensitivity, Test, ForwardDiff

# Version from the test (works)
function fiip(du, u, p, t)
    du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2]
end
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0; 1.0]
condition(u, t, integrator) = t == 5
affect!(integrator) = (integrator.p[1] = 2 * integrator.p[1] .- 0.5)
cb = DiscreteCallback(condition, affect!, save_positions = (false, false))
tstops = [5.0]
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)
du01, dp1 = Zygote.gradient(
    (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
        callback = cb, tstops = tstops,
        saveat = 0.5,
        sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()))),
    u0, p)

## Version using ComponentArrays
using ComponentArrays
function fiip(du, u, p, t)
    du[1] = dx = p[:a] * u[1] - p[:b] * u[1] * u[2]
    du[2] = dy = -p[:c] * u[2] + p[:d] * u[1] * u[2]
end
p = ComponentArray(a = 1.5, b = 1.0, c = 3.0, d = 1.0)
u0 = [1.0; 1.0]
condition(u, t, integrator) = t == 5
affect!(integrator) = (integrator.p[:a] = 2 * integrator.p[1] .- 0.5) # DIFFERENCE HERE: index with :a instead of 1
cb = DiscreteCallback(condition, affect!, save_positions = (false, false))
tstops = [5.0]
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)
du01, dp1 = Zygote.gradient(
    (u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
        callback = cb, tstops = tstops,
        saveat = 0.5,
        sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP()))),
    u0, p)
```
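A possible stopgap, sketched under stated assumptions and not verified against `ReverseDiffVJP()`: keep named access in the callback by mapping names to integer positions once, outside the callback, so the update only ever uses integer indexing, which works on both a `ComponentArray` and a plain `Vector`. `PARAM_IDX`, `affect_by_name!`, and `MockIntegrator` are hypothetical names for illustration, and the fixed positions assume every component of `p` is a scalar laid out in the order `a, b, c, d`.

```julia
# Hypothetical name map: flat positions of each scalar parameter.
# Assumes p = ComponentArray(a = ..., b = ..., c = ..., d = ...) with scalar entries.
const PARAM_IDX = (a = 1, b = 2, c = 3, d = 4)

# Same update as the working MWE, addressed by name through an integer index.
affect_by_name!(integrator) =
    (integrator.p[PARAM_IDX.a] = 2 * integrator.p[PARAM_IDX.a] - 0.5)

# Minimal stand-in for an integrator: only the `p` field is used here.
struct MockIntegrator{P}
    p::P
end

integ = MockIntegrator([1.5, 1.0, 3.0, 1.0])
affect_by_name!(integ)
println(integ.p)  # [2.5, 1.0, 3.0, 1.0]
```

In the MWE this would be passed as `DiscreteCallback(condition, affect_by_name!, save_positions = (false, false))`. It trades the convenience of `p[:a]` for a hand-maintained index map, so it only makes sense as a workaround while symbol indexing inside callbacks is unsupported.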
Error & Stacktrace ⚠️
Environment (please complete the following information):
- `using Pkg; Pkg.status()`
- `using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)`
- `versioninfo()`