SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Sensitivities of an SDE w.r.t. tspan #693

Open stochasticguy opened 2 years ago

stochasticguy commented 2 years ago

Hi,

I’m using this amazing package to get the sensitivities of an SDE problem. Unfortunately, the results I obtain are not correct. I know they are not correct because I can compare them with analytical solutions.

The problem can be defined by:

using DifferentialEquations
using SciMLSensitivity
using ForwardDiff
using Statistics

f(u, p, t) = p[2] * u  # drift: μ * u (geometric Brownian motion)
g(u, p, t) = p[3] * u  # diffusion: σ * u
out(u, p)  = max(u(p[4])[1] - p[5], 0.0)  # call payoff max(S_T - K, 0) at t = T

S0 = 100.
μ  = 0.02
σ  = 0.12
K  = 90.0
T  = 1.0
p  = [S0, μ, σ, T, K]
tspan = (0.0, p[4])

prob = SDEProblem{false}(f, g, S0, tspan, p)

function mean_of_solution(x)
    _prob = remake(prob; u0 = x[1], p = x, tspan=(0.0, x[4]))
    ens   = EnsembleProblem(_prob, output_func = (sol, i) -> (out(sol, x), false))
    sol   = solve(ens, EM(); dt=1/252, trajectories=100000, sensealg=ForwardDiffSensitivity())
    v     = exp(- x[2] * x[4]) * mean(sol)

    return v
end

When computing:

res = ForwardDiff.gradient(mean_of_solution, p)

I get the following values:

5-element Vector{Float64}:
  0.7980157642394161
 65.75259572802794
 41.875716241470954
  1.2615647968685262
 -0.7307577149236522

which are not correct at all.

On the other hand, when I replace the mean_of_solution method by:

function mean_of_solution1(x)
    _prob = remake(prob; u0 = x[1], p = x)
    ens   = EnsembleProblem(_prob, output_func = (sol, i) -> (out(sol, x), false))
    sol   = solve(ens, EM(); dt=1/252, trajectories=100000, sensealg=ForwardDiffSensitivity())
    v     = exp(- x[2] * x[4]) * mean(sol)

    return v
end

the results I get from computing:

res1 = ForwardDiff.gradient(mean_of_solution1, p)

are now:

5-element Vector{Float64}:
  0.865296625048412
 73.89273300510901
 21.46926842210546
  3.926001602643431
 -0.8212202504831326

These results are all correct except for res1[4], which is far from the correct value; it also shows large run-to-run variability.

The correct results for this problem are:

5-element Vector{Float64}:
  0.8653489052824384
 73.89192667583923
 21.673334577979716
 -2.7782386081955677
 -0.8210214075093247

which, as I mentioned at the beginning, can be obtained by means of closed form solutions.
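For reference, a sketch of one way to reproduce such closed-form values: assuming the payoff above is a European call under GBM, the discounted mean `exp(-μT) * E[max(S_T - K, 0)]` equals the Black–Scholes price with rate `r = μ`, so the reference gradient can be obtained by differentiating the closed form directly (the function name `bs_price` is mine, not from the issue):

```julia
using ForwardDiff
using SpecialFunctions: erf

Φ(x) = (1 + erf(x / sqrt(2))) / 2  # standard normal CDF

# Black–Scholes price of a European call, with the issue's parameter layout
# x = [S0, μ, σ, T, K] and r = μ.
function bs_price(x)
    S0, μ, σ, T, K = x
    d1 = (log(S0 / K) + (μ + σ^2 / 2) * T) / (σ * sqrt(T))
    d2 = d1 - σ * sqrt(T)
    return S0 * Φ(d1) - K * exp(-μ * T) * Φ(d2)
end

g = ForwardDiff.gradient(bs_price, [100.0, 0.02, 0.12, 1.0, 90.0])
# g[1] ≈ 0.8653 (delta), g[2] ≈ 73.89, g[3] ≈ 21.67 (vega), g[5] ≈ -0.8210
```

One caveat: this gives `g[4] ≈ +2.778` (the gradient of the price with respect to `T`), whereas the quoted reference value is `-2.778`; the difference appears to be only a sign convention (theta is often quoted as the sensitivity to a *decrease* in time to maturity).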

Any idea or comment is welcome.

Thanks in advance!

frankschae commented 2 years ago

I can take a look. Did you check if the issue is actually specific to SDEs? I think that could even have issues for ODEs at the moment, as we mostly test/have examples for initial conditions and parameters.

See also: https://github.com/SciML/SciMLSensitivity.jl/issues/352 https://github.com/SciML/SciMLSensitivity.jl/issues/46

stochasticguy commented 2 years ago

Thanks for your answer! Yes, unfortunately it seems it's not supported at the moment.

Do you know what it would take to implement it?

ChrisRackauckas commented 2 years ago

Is this well-defined? I am not sure. I think we might want to throw a nice error somehow mentioning that.

frankschae commented 2 years ago

IIRC, for ODEs, one can compute it by adding an additional integration (similar to the gradient with respect to the parameters)

$$ \frac{d \lambda_t}{dt} = - \lambda \frac{\partial f(u,p,t)}{\partial t}, $$

where

$\lambda$ is the adjoint state, such that $\lambda_t(t = t_0) = \left.\frac{dL}{dt}\right|_{t = t_0}$. The sensitivities with respect to the initial and final time points can be computed even more easily, I think; see https://frankschae.github.io/post/bouncing_ball/ .
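As a quick sanity check of the end-time rule for ODEs (a sketch on a toy scalar problem of my choosing, not this issue's SDE): by the chain rule, the sensitivity of the final state with respect to the final time is just the vector field at the endpoint, $du(T)/dT = f(u(T), p, T)$, which can be confirmed against a known closed-form solution:

```julia
using ForwardDiff

# Toy problem: du/dt = p*u with u(0) = u0 has the closed-form solution
# u(T) = u0 * exp(p*T), so du(T)/dT should equal f(u(T), p, T) = p * u(T).
u0, p, T = 1.0, 0.3, 2.0
u_end(t) = u0 * exp(p * t)                 # solution evaluated at the end time
dudT = ForwardDiff.derivative(u_end, T)    # end-time sensitivity
dudT ≈ p * u_end(T)                        # true: matches the vector field
```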

ChrisRackauckas commented 2 years ago

Yes, and we should add the ODE one. But for SDEs, the solution is the drift plus a random process whose derivative is not defined in a pointwise way. I'm not entirely sure that derivative exists. Its average derivative at the end point is the drift term, but I'm not sure the strong derivative exists.
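To make the "average derivative is the drift" point concrete for the GBM in this issue (a sketch about the mean only, not about pathwise derivatives):

$$ \mathbb{E}[u_T] = u_0 e^{\mu T} \quad\Rightarrow\quad \frac{d}{dT}\,\mathbb{E}[u_T] = \mu\, u_0 e^{\mu T} = \mathbb{E}\left[f(u_T, p, T)\right], $$

so the derivative of the mean at the endpoint is the expected drift, while a pathwise (strong) derivative would involve $dW_T/dT$, which does not exist.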