SciML / EasyModelAnalysis.jl

High level functions for analyzing the output of simulations
MIT License

Add a probability calculator for scenarios #16

Closed: ChrisRackauckas closed this issue 1 year ago

ChrisRackauckas commented 1 year ago

Given parameter distributions as inputs, e.g. `[p1 => Normal(0,1)]`, compute the probability of a scenario.
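In its simplest form, such a calculator estimates the probability of a scenario by sampling. Below is a minimal Monte Carlo sketch. It makes three assumptions: plain sampling is used instead of the quadrature used later in this thread, the priors are independent, and each prior is a sampler function `rng -> value` rather than a Distributions.jl object like `Normal(0,1)` so that the example has no dependencies. `scenario_probability` is a hypothetical name, not part of EasyModelAnalysis.

```julia
using Random

# Hypothetical helper (not EasyModelAnalysis API): estimate P(event(params))
# by drawing each parameter from its sampler and averaging a 0/1 indicator.
function scenario_probability(event, priors::Vector{<:Pair}; n::Int = 100_000,
                              rng::AbstractRNG = Random.default_rng())
    hits = 0
    for _ in 1:n
        # One joint draw (independent priors), stored as name => sampled value.
        params = Dict(name => sampler(rng) for (name, sampler) in priors)
        hits += event(params) ? 1 : 0
    end
    return hits / n
end

# p1 ~ Normal(0, 1), written as a sampler since Distributions.jl is not loaded.
priors = [:p1 => (rng -> randn(rng))]
prob_positive = scenario_probability(params -> params[:p1] > 0, priors;
                                     rng = MersenneTwister(1))
```

With `n` samples, the standard error is about `sqrt(P * (1 - P) / n)`. This is one reason a quadrature approach is attractive when only a few parameters are uncertain.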

ChrisRackauckas commented 1 year ago
  1. Based on the forecast, do we need interventions to keep total Covid hospitalizations under a threshold of 3000 on any given day? If there is uncertainty in the model parameters, express the answer probabilistically, i.e., what is the likelihood or probability that the number of Covid hospitalizations will stay under this threshold for the next 3 months without interventions?

  2. Based on the forecasts, do we need additional interventions to keep cumulative Covid deaths under 6000 total? Provide a probability that the cumulative number of Covid deaths will stay under 6000 for the next 6 weeks without any additional interventions.

Use SciMLExpectations as appropriate.
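Both questions reduce to averaging a 0/1 indicator over uncertain trajectories. Question 1 asks whether daily hospitalizations stay under 3000 on every day of the window. Question 2 asks whether the running death total stays under 6000. The sketch below makes several assumptions. Trajectories are plain daily vectors, "3 months" is taken as 90 days, and "6 weeks" as 42 days. The ensemble is made-up placeholder data, not a forecast. `stays_under` and `cumulative_stays_under` are hypothetical names.

```julia
using Statistics

# Question 1: indicator that the daily value stays strictly below `threshold`
# on every one of the first `days` entries.
stays_under(daily, threshold; days) = all(<(threshold), daily[1:min(days, end)])

# Question 2: indicator that the running total, starting from `already`
# (deaths recorded before the forecast starts), stays below `threshold`.
function cumulative_stays_under(daily, threshold; days, already = 0.0)
    running = already .+ cumsum(daily[1:min(days, end)])
    return all(<(threshold), running)
end

# Placeholder ensemble, e.g. one hospitalization trajectory per sampled parameter set.
ensemble = [fill(2500.0, 90),
            vcat(fill(2500.0, 45), fill(3100.0, 45)),
            fill(1000.0, 120)]

# Averaging the indicator over equally weighted samples estimates the probability.
p_hosp = mean(traj -> stays_under(traj, 3000; days = 90), ensemble)
```

Here two of the three placeholder trajectories stay under 3000, so `p_hosp` is 2/3. With an uncertainty-quantification method, the equal weights are replaced by the weights that method assigns to each sample.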

ChrisRackauckas commented 1 year ago

This needs to support SDEs for MechBayes.
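For SDEs, the observable is random even at fixed parameters. The probability therefore has to average over both the parameter prior and the noise paths. Below is a plain Monte Carlo sketch of that double average. It assumes a hypothetical scalar Ornstein-Uhlenbeck-type SDE `dX = -θX dt + s dW`, integrated with Euler-Maruyama. It is not the MechBayes model, and the function names are illustrative.

```julia
using Random

# One Euler-Maruyama path of dX = -θ X dt + s dW on [0, T]; returns the path maximum.
function em_path_max(θ, s, x0, T, dt, rng)
    x, xmax = x0, x0
    for _ in 1:round(Int, T / dt)
        x += -θ * x * dt + s * sqrt(dt) * randn(rng)
        xmax = max(xmax, x)
    end
    return xmax
end

# Each sample draws θ from its prior and then a fresh noise path, so the
# estimate averages over parameter uncertainty and intrinsic noise together.
function violation_probability_sde(θ_sampler, threshold; s = 0.5, x0 = 0.0,
                                   T = 10.0, dt = 0.01, n = 10_000,
                                   rng = MersenneTwister(0))
    hits = 0
    for _ in 1:n
        θ = θ_sampler(rng)
        hits += em_path_max(θ, s, x0, T, dt, rng) > threshold ? 1 : 0
    end
    return hits / n
end

# θ ~ Uniform(0.5, 1.5), written as a sampler without Distributions.jl.
p_sde = violation_probability_sde(rng -> 0.5 + rand(rng), 1.0)
```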

ArnoStrouwen commented 1 year ago

Unless we heavily restrict what the function `obs` can be, I don't see an easy way to avoid having users write it themselves. `h` can probably be abstracted away.
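The `h` in the example below hard-codes where σ sits in the parameter vector. If the uncertain parameters can be identified by their positions in `prob.p`, `h` could be generated instead of hand-written. `make_h` is a hypothetical helper, not SciMLExpectations API. The later comment in this thread takes a name-based route via `remake`, which avoids depending on parameter order.

```julia
# Hypothetical helper: build the `h(x, u, p)` map used by ExpectationProblem
# from the positions `idxs` of the uncertain parameters inside `p`.
# `x` is one sample of those parameters, ordered like `idxs`.
function make_h(idxs::AbstractVector{<:Integer})
    return function (x, u, p)
        pnew = collect(p)      # plain Vector copy; the caller's p is left untouched
        pnew[idxs] .= x        # overwrite only the sampled entries
        return u, pnew
    end
end

# Equivalent to h(x, u, p) = u, [x[1], p[2], p[3]] when σ is the first entry of p.
h = make_h([1])
u_new, p_new = h([30.0], [1.0, 0.0], [28.0, 10.0, 8 / 3])
```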

```julia
using Pkg
Pkg.activate(".")
using EasyModelAnalysis
# truncated/Normal come from Distributions.jl and HCubatureJL from Integrals.jl;
# add `using Distributions, Integrals` if they are not already in scope.

@parameters t σ ρ β
@variables x(t) y(t) z(t)
D = Differential(t)

eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x * (ρ - z) - y,
    D(z) ~ x * y - β * z]

@named sys = ODESystem(eqs)
sys = structural_simplify(sys)

u0 = [D(x) => 2.0,
    x => 1.0,
    y => 0.0,
    z => 0.0]

p = [σ => 28.0,
    ρ => 10.0,
    β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob)

using SciMLExpectations
# Observable: 1.0 if x stays at or below 10 over the whole solution, else 0.0,
# so its expectation is the probability of staying under the threshold.
function obs(sol, p)
    maximum(sol[x]) > 10 ? 0.0 : 1.0
end
obs(sol, prob.p)

σ_dist = truncated(Normal(28.0, 1.0), 20.0, 40.0)
gd = GenericDistribution(σ_dist)
sm = SystemMap(prob, sol.alg)
# Maps a sample x = [σ] to (u0, p); assumes σ is the first entry of prob.p.
h(x, u, p) = u, [x[1], p[2], p[3]]
exprob = ExpectationProblem(sm, obs, h, gd; nout = 1)
exsol = solve(exprob, Koopman(), batch = 0, quadalg = HCubatureJL())
```

```
julia> exsol.u
0.00039102688506400586
```
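With `Koopman()`, the expectation is computed as an integral over the uncertain parameter, E[obs] = ∫ obs(S(σ)) f(σ) dσ. Here S(σ) is the solution for parameter σ and f is the truncated-normal density. Because `obs` returns 0 or 1, this integral is the probability that x stays at or below 10. Below is a one-dimensional, dependency-free sketch of that idea. It uses a trapezoid rule in place of `HCubatureJL()` and a cheap stand-in observable instead of an ODE solve; all names are hypothetical.

```julia
# Unnormalized N(μ, sd) density; truncation to [a, b] is handled by
# integrating only over [a, b] and renormalizing.
normal_kernel(σ; μ = 28.0, sd = 1.0) = exp(-((σ - μ) / sd)^2 / 2)

# Composite trapezoid rule for ∫_a^b f over n intervals.
function trapezoid(f, a, b; n = 10_000)
    h = (b - a) / n
    s = (f(a) + f(b)) / 2
    for k in 1:n-1
        s += f(a + k * h)
    end
    return s * h
end

# E[g(σ)] under the density proportional to `kernel` on [a, b].
function expected_indicator(g, kernel, a, b; n = 10_000)
    Z = trapezoid(kernel, a, b; n = n)
    return trapezoid(σ -> g(σ) * kernel(σ), a, b; n = n) / Z
end

# Stand-in observable: 1.0 when σ < 28, 0.0 otherwise (no ODE solve involved).
p_below = expected_indicator(σ -> σ < 28.0 ? 1.0 : 0.0, normal_kernel, 20.0, 40.0)
```

Swapping the stand-in for a function that solves the ODE at each σ recovers the structure of the `ExpectationProblem` above, at the cost of one solve per quadrature node.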
ArnoStrouwen commented 1 year ago

This should be enough for most scenarios?

```julia
using EasyModelAnalysis
# truncated/Normal come from Distributions.jl and HCubatureJL from Integrals.jl;
# add `using Distributions, Integrals` if they are not already in scope.

@parameters t σ ρ β
@variables x(t) y(t) z(t)
D = Differential(t)

eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x * (ρ - z) - y,
    D(z) ~ x * y - β * z]

@named sys = ODESystem(eqs)
sys = structural_simplify(sys)

u0 = [D(x) => 2.0,
    x => 1.0,
    y => 0.0,
    z => 0.0]

p = [σ => 28.0,
    ρ => 10.0,
    β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob)

using SciMLExpectations
# Each threshold is a symbolic inequality `variable op bound`,
# with the variable on the left-hand side.
thresholds = [x > 10.0, y < -5.0]
p_prior = [σ => truncated(Normal(28.0, 1.0), 20.0, 40.0),
    β => truncated(Normal(2.7, 0.1), 2.0, 4.0)]

# Probability that at least one threshold is violated somewhere on tspan,
# given the parameter priors in p_prior.
function prob_violating_threshold(prob, p_prior, thresholds)
    pkeys = first.(p_prior)
    p_dist = last.(p_prior)
    gd = GenericDistribution(p_dist...)
    sol = solve(prob)  # only used to pick the solver algorithm
    sm = SystemMap(prob, sol.alg)
    # Map a sample x to (u0, p) by name; remake does not work well with static arrays.
    h(x, u, p) = u, remake(prob, p = Pair.(pkeys, [x...])).p
    function g(sol, p)
        for threshold in thresholds
            op = threshold.val.f  # comparison operator
            var, bound = threshold.val.arguments
            if op == (>) || op == (>=)
                # upper threshold: violated if the maximum crosses it
                op(maximum(sol[var]), bound) && return 1.0
            elseif op == (<) || op == (<=)
                # lower threshold: violated if the minimum crosses it
                op(minimum(sol[var]), bound) && return 1.0
            else
                error("Unsupported threshold operator: $op")
            end
        end
        return 0.0
    end
    exprob = ExpectationProblem(sm, g, h, gd; nout = 1)
    exsol = solve(exprob, Koopman(), batch = 0, quadalg = HCubatureJL())
    exsol.u
end
prob_violating_threshold(prob, p_prior, thresholds)
```
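The threshold check inside `g` can be exercised without Symbolics or an ODE solve. This sketch assumes each threshold is stored as a hypothetical `(series name, operator, bound)` tuple and each solution as a `Dict` of time series. An upper threshold (`>`, `>=`) is violated when the operator holds for the series maximum. A lower threshold (`<`, `<=`) is violated when it holds for the minimum. The result is 1.0 if any threshold is violated, so its expectation is the probability of the union of violations.

```julia
# Hypothetical stand-in for the symbolic thresholds: (series name, operator, bound).
const Threshold = Tuple{Symbol, Function, Float64}

# 1.0 if any threshold is violated by the trajectories in `series`, else 0.0.
function violates_any(series::Dict{Symbol, <:AbstractVector}, thresholds::Vector{Threshold})
    for (name, op, bound) in thresholds
        if op === (>) || op === (>=)
            op(maximum(series[name]), bound) && return 1.0   # upper bound crossed
        elseif op === (<) || op === (<=)
            op(minimum(series[name]), bound) && return 1.0   # lower bound crossed
        else
            error("unsupported threshold operator: $op")
        end
    end
    return 0.0
end

traj = Dict(:x => [1.0, 12.0, 3.0], :y => [0.0, -2.0, 1.0])
thresholds = Threshold[(:x, >, 10.0), (:y, <, -5.0)]
result = violates_any(traj, thresholds)
```

Consider the case with only the `x > 10.0` threshold and only the σ prior. There, this quantity is one minus the probability from the first example, whose `obs` returns 1.0 when the threshold is not crossed.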
ChrisRackauckas commented 1 year ago

Yes that looks great