SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Sensitivity analysis on tspan parameters #352

Open mattborghi opened 3 years ago

mattborghi commented 3 years ago

I am trying to differentiate an SDE solution w.r.t. a tspan parameter t₀ using Zygote.jl reverse mode, but the results are wrong.

An MWE is presented below, comparing against Zygote.jl forward mode, finite differences, and a SavingCallback-based approach.

# load packages 

using DifferentialEquations
using FiniteDifferences
using DiffEqSensitivity
using ForwardDiff: Dual
using Statistics
using Zygote

# create SDE Problem 

r = 0.05
σ = 0.14
X₀ = 1.0
p = [r;σ]
u0 = [X₀]
f(u,p,t) = p[1] * u
g(u,p,t) = p[2] * u
dt = 1 // 52
T = 1.0
t₀ = 0.0
tspan = (t₀, T)
prob = SDEProblem(f, g, u0, tspan, p)

# Define loss function

# loss: mean of the terminal values over an ensemble of 10_000 trajectories,
# with the initial time t₀ passed in as the differentiated argument x
loss(x, solver, sensealg) = begin
    T = eltype(x)  # promote u0 to eltype(x) so Dual numbers propagate
    tmp_prob = remake(prob; tspan=(x, tspan[2]), u0=T.(u0))
    ens = EnsembleProblem(tmp_prob)
    sol = Array(solve(ens, solver; sensealg, saveat=dt, dt, trajectories=10_000))
    return sol[:, end, :] |> mean
end

Evaluating the function works fine:

# call function
loss(t₀, SRIW1(), QuadratureAdjoint()) # 1.0508689728556049
loss(t₀, SOSRI(), QuadratureAdjoint()) # 1.0505429222336116
loss(t₀, EM(), QuadratureAdjoint()) # 1.0517452261120528
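
As a quick sanity check (a reference computation, not part of the original MWE), these values can be compared against the closed-form expectation for the geometric Brownian motion defined above:

# E[u(T)] = X₀ exp(r (T - t₀)) for GBM
X₀ * exp(r * (T - t₀))  # ≈ 1.051271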

Calculating derivatives using Duals, Zygote.jl forward mode, and finite differences:

Minor comment here: if I pass saveat, Julia crashes for Duals/ForwardDiff. The finite-difference results show huge variability.

# using Duals
loss(Dual(t₀, 1), SRIW1(), QuadratureAdjoint()) # Dual{Nothing}(1.0615800906755797,-0.04235645034067541)
loss(Dual(t₀, 1), EM(), QuadratureAdjoint()) # Dual{Nothing}(1.0541398210078339,-0.05254337981012676)

# Zygote.jl Forward Mode
Zygote.gradient(x -> Zygote.forwarddiff(y -> loss(y, SRIW1(), QuadratureAdjoint()), x), t₀)[1] # -0.042399034428430175

# using FiniteDifferences.jl
FiniteDifferences.central_fdm(10,1)(x -> loss(x, SRIW1(), QuadratureAdjoint()), t₀) # -0.17984975695312383
FiniteDifferences.central_fdm(10,1)(x -> loss(x, SOSRI(), QuadratureAdjoint()), t₀) # 0.6018134238307663
FiniteDifferences.central_fdm(10,1)(x -> loss(x, EM(), QuadratureAdjoint()), t₀) # 0.2780011733271547
FiniteDifferences.forward_fdm(10,1)(x -> loss(x, EM(), QuadratureAdjoint()), t₀) # -221.14853371493038
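
The same closed form gives a deterministic target for the t₀-derivative (again a reference computation under the GBM assumption):

# ∂E[u(T)]/∂t₀ = -r X₀ exp(r (T - t₀))
-r * X₀ * exp(r * (T - t₀))  # ≈ -0.052564

This is in the same ballpark as the Dual/forward-mode values above (the EM result in particular), while the finite-difference estimates are dominated by the stochastic variation between calls.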

Comparison against the SavingCallback: the results differ by a minus sign, which is expected.

# use SavingCallback
loss_callback(x, solver, sensealg) = begin
    T = eltype(x)  # promote u0 to eltype(x), as in loss above
    tmp_prob = remake(prob; tspan=(x, tspan[2]), u0=T.(u0))
    ens = EnsembleProblem(tmp_prob)
    saved_values = SavedValues(T, T)
    # at t = 0, evaluate the derivative of the solution interpolation via
    # integrator(t, Val{1}) and accumulate it across trajectories
    cb = SavingCallback(
        (u, t, integrator) -> begin
            res = integrator(t, Val{1})
            return isempty(saved_values.saveval) ? res[1] : (saved_values.saveval + res)[1]
        end,
        saved_values;
        saveat=[0.0]
    )
    sol = solve(ens, solver; sensealg, trajectories=10_000, dt, saveat=dt, callback=cb)[:, end, :]
    # mean terminal value and the accumulated t₀-derivative averaged over trajectories
    return mean(sol), saved_values.saveval[end][1] / 10_000
end

loss_callback(t₀, SRIW1(), QuadratureAdjoint()) # (1.0508679798497187, 0.04881618234432134)
loss_callback(t₀, EM(), QuadratureAdjoint()) # (1.0514073528478032, 0.05091793113552826)

Finally, using Zygote.jl Reverse Mode:

# Julia crashes if we don't use saveat
Zygote.gradient(x -> loss(x, SRIW1(), ForwardDiffSensitivity()), t₀) # (nothing,)
Zygote.gradient(x -> loss(x, SOSRI(), ForwardDiffSensitivity()), t₀) # (nothing,)
Zygote.gradient(x -> loss(x, EM(), ForwardDiffSensitivity()), t₀) # (nothing,)

Thanks in advance.

ChrisRackauckas commented 3 years ago

The derivative w.r.t. tspan parameters isn't implemented even for ODEs. We just haven't thought about doing that. If you have a use case, we can keep this in mind and add it down the line though! It should just be a call to the rhs at the first and last time points, right?
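
For the autonomous ODE case that intuition can be checked directly; a minimal sketch (the closed-form solution u(T) = u₀ exp(r (T − t₀)) is an assumption of this example, not part of the thread):

using FiniteDifferences
# for an autonomous ODE u' = f(u), shifting t₀ shifts the whole solution,
# so du(T)/dt₀ = -f(u(T)); here u(T) = u₀ exp(r (T - t₀)) in closed form
r, u₀, t₀, T = 0.05, 1.0, 0.0, 1.0
uT = u₀ * exp(r * (T - t₀))
analytic = -r * uT  # -f(u(T))
fd = FiniteDifferences.central_fdm(5, 1)(x -> u₀ * exp(r * (T - x)), t₀)
analytic ≈ fd  # true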

The finite-difference results show huge variability

For an SDE, you're computing the strong (pathwise) derivative, so the derivative itself is stochastic and will change on every call. Forward-mode AD will work just fine for this calculation though. Finite differencing would require that you provide a fixed noise process.
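
A sketch of one way to fix the noise for such a finite-difference check, assuming a fixed seed per solve is enough to couple the perturbed paths (reusing prob, dt, and t₀ from the MWE above):

# with a fixed seed, both FD evaluations see the same Wiener increments
loss_fixed(x) = begin
    tmp_prob = remake(prob; tspan=(x, tspan[2]))
    sol = solve(tmp_prob, EM(); dt, seed=1234)
    return sol[end][1]
end
FiniteDifferences.central_fdm(5, 1)(loss_fixed, t₀)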

Comparison against the SavingCallback: the results differ by a minus sign, which is expected.

That works because the derivative of the mean is deterministic even though the pathwise derivative itself is not.

Discretize-then-optimize: works only for ForwardDiffSensitivity() and the value returned is always the same.

Yeah that's just not implemented.

mschauer commented 3 years ago

The "derivative" of X(t) with respect to t is f(X(t)) + g(X(t)) dW/dt by definition, so you should be able to work out a closed form.
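
Spelled out for the MWE's drift and diffusion (a sketch): with f(x) = r x and g(x) = σ x, taking expectations removes the noise term, so dE[X(t)]/dt = r E[X(t)], hence E[X(T)] = X₀ exp(r (T − t₀)) and ∂E[X(T)]/∂t₀ = −r X₀ exp(r (T − t₀)), consistent with the forward-mode values reported above.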

ChrisRackauckas commented 3 years ago

Yeah, so it's not difficult but just needs to be implemented.

mattborghi commented 3 years ago

The derivative w.r.t. tspan parameters isn't implemented even for ODEs. We just haven't thought about doing that. If you have a use case, we can keep this in mind and add it down the line though! It should just be a call to the rhs at the first and last time points, right?

The idea of differentiating the tspan parameters comes from finance, where we have to compute

∂E[u(T)] / ∂t₀

where u is the solution of the SDE. It also has uses in many other applications, but this is the most straightforward one.

If you want, I can point you to some literature where this is implemented so it can be benchmarked.

By the way, what do you mean by your last sentence?

ChrisRackauckas commented 3 years ago

Yes, that's exactly what I meant. For ODEs, it's just f(u(0)) and f(u(tend)) that give that derivative term. For SDEs, I haven't worked out exactly how to take the diffusion term into account. But we should at least add this to the ODE part of the interface ASAP since that's easy. For ODEs it can be done using the derivative of the interpolation, which would make it essentially free.
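
A minimal sketch of that last point (the ODE problem here is illustrative, not from the thread): the solution object can already evaluate the derivative of its own interpolation, so the endpoint values come essentially for free:

using DifferentialEquations
# sol(t, Val{1}) evaluates the first derivative of the solution interpolation,
# mirroring the integrator(t, Val{1}) call used in the SavingCallback above
fode(u, p, t) = p[1] * u
sol = solve(ODEProblem(fode, [1.0], (0.0, 1.0), [0.05]), Tsit5())
sol(0.0, Val{1})  # ≈ fode(u(0), p, 0.0)
sol(1.0, Val{1})  # ≈ fode(u(tend), p, tend)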