A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
Describe the bug 🐞
Consider the example in the docs, but suppose one wishes to take the Jacobian with respect to all parameters at a given time. The docs example takes the 1st element of the solution at all times (sol[i, :] with i = 1),
which works with both ForwardDiff and Zygote. However, when one wishes to calculate the Jacobian of all elements at a given time, i.e. sol[:, i] instead of sol[i, :] as in the example, ForwardDiff gives the result but Zygote can give all zeros, depending on how the solution is accessed. Full code:
using OrdinaryDiffEq
using Zygote
using SciMLSensitivity, ForwardDiff

# In-place Lotka-Volterra right-hand side
function lotka_volterra!(du, u, p, t)
    du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2]
end

p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)
sol0 = solve(prob, Tsit5(), reltol=1e-6, abstol=1e-6)

# Full state at the i-th saved time point, as a function of x = [u0; p]
function f_sol2(x, i)
    _prob = remake(prob, u0=x[1:2], p=x[3:end])
    sol = solve(_prob, Tsit5(), reltol=1e-6, abstol=1e-6, saveat=1)
    sol[:, i]
end

u0p = [u0; p]
df_zyg = Zygote.jacobian(x -> f_sol2(x, 6), u0p)[1]
df_fwd = ForwardDiff.jacobian(x -> f_sol2(x, 6), u0p)
In this case, df_zyg is all zeros, while df_fwd is correct. The same happens if one accesses the solution through the interpolating sol(t) interface. However, if one accesses the solution at a given time step through the sol.u interface, Zygote works,
which is a bit strange, since sol.u[i] and sol[:, i] are identical in the === sense, i.e. they point to the same memory: sol0.u[3] === sol0[:, 3], which can also be checked with pointer(sol0.u[3]).
Expected behavior
Zygote.jacobian should work with sol[:, i].
Minimal Reproducible Example 👇
using OrdinaryDiffEq
using Zygote
using SciMLSensitivity, ForwardDiff

# In-place Lotka-Volterra right-hand side
function lotka_volterra!(du, u, p, t)
    du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2]
end

p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)
sol0 = solve(prob, Tsit5(), reltol=1e-6, abstol=1e-6)

# Variant 1: index the solution object directly (Zygote returns all zeros)
function f_sol2(x, i)
    _prob = remake(prob, u0=x[1:2], p=x[3:end])
    sol = solve(_prob, Tsit5(), reltol=1e-6, abstol=1e-6, saveat=1)
    sol[:, i]
end

# Variant 2: index the underlying vector of states (Zygote works)
function f_sol_u(x, i)
    _prob = remake(prob, u0=x[1:2], p=x[3:end])
    sol = solve(_prob, Tsit5(), reltol=1e-6, abstol=1e-6, saveat=1)
    sol.u[i]
end

u0p = [u0; p]

# Jacobians through sol[:, i]
df_zyg = Zygote.jacobian(x -> f_sol2(x, 6), u0p)[1]
df_fwd = ForwardDiff.jacobian(x -> f_sol2(x, 6), u0p)

# Jacobians through sol.u[i] (these overwrite the results above)
df_zyg = Zygote.jacobian(x -> f_sol_u(x, 6), u0p)[1]
df_fwd = ForwardDiff.jacobian(x -> f_sol_u(x, 6), u0p)

# Both accessors return the same object
sol0.u[3] === sol0[:, 3]
(added note)
Zygote.jacobian works when returning:
@views sol[:, i]
or
sol.u[i]
Zygote.jacobian is wrong (all zeros) when returning:
sol[:, i]
or
sol(i)
Error & Stacktrace ⚠️
N/A: no error is thrown; Zygote returns zeros where the result should be non-zero.
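Because the failure is silent, a reference that does not depend on any AD package can help tell which result to trust. Below is a minimal sketch of a central finite-difference Jacobian in plain Julia. The helper fd_jacobian and the default step h are hypothetical choices made for illustration; they are not part of SciMLSensitivity, Zygote, or ForwardDiff.

```julia
# Hypothetical helper: central finite-difference Jacobian of f at x.
# Assumes f maps a vector to a vector; h is an illustrative default step.
function fd_jacobian(f, x::AbstractVector{<:Real}; h = 1e-3)
    y0 = f(float.(x))                      # evaluate once to size the output
    J = zeros(length(y0), length(x))
    for j in 1:length(x)
        xp = float.(x); xp[j] += h         # broadcasting allocates a fresh copy
        xm = float.(x); xm[j] -= h
        J[:, j] = (f(xp) .- f(xm)) ./ (2h) # j-th column: d f / d x[j]
    end
    return J
end

# Toy check on a function whose Jacobian is known: [2a 0; b a] at x = [a, b].
g(x) = [x[1]^2, x[1] * x[2]]
J = fd_jacobian(g, [2.0, 3.0])
println(J)  # should be close to [4.0 0.0; 3.0 2.0]
```

Assuming the definitions from the reproducer are loaded, fd_jacobian(x -> f_sol2(x, 6), u0p) could then be compared against df_zyg and df_fwd, for instance with isapprox and a loose rtol, while all(iszero, df_zyg) flags the silent-zero case. The step h should stay well above the solver tolerances (1e-6 here); otherwise solver error dominates the difference quotient.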
Environment (please complete the following information):
using Pkg; Pkg.status()
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
versioninfo()