SciML / EasyModelAnalysis.jl

High level functions for analyzing the output of simulations
MIT License

Weighted Bayesian Ensemble fits #191

Open chriselrod opened 1 year ago

chriselrod commented 1 year ago

I am running the fits now and hoping they converge; I haven't looked at the results yet. In its current form, the following example is too slow for CI:

using EasyModelAnalysis, LinearAlgebra

@parameters t β=0.05 c=10.0 γ=0.25
@variables S(t)=990.0 I(t)=10.0 R(t)=0.0
∂ = Differential(t)
N = S + I + R # This is recognized as a derived variable
eqs = [∂(S) ~ -β * c * I / N * S,
    ∂(I) ~ β * c * I / N * S - γ * I,
    ∂(R) ~ γ * I];

@named sys = ODESystem(eqs);
tspan = (0,30)
prob = ODEProblem(sys, [], tspan);

@parameters t β=0.1 c=10.0 γ=0.25 ρ=0.1 h=0.1 d=0.1 r=0.1
@variables S(t)=990.0 I(t)=10.0 R(t)=0.0 H(t)=0.0 D(t)=0.0
∂ = Differential(t)
N = S + I + R + H + D # This is recognized as a derived variable
eqs = [∂(S) ~ -β * c * I / N * S,
    ∂(I) ~ β * c * I / N * S - γ * I - h * I - ρ * I,
    ∂(R) ~ γ * I + r * H,
    ∂(H) ~ h * I - r * H - d * H,
    ∂(D) ~ ρ * I + d * H];

@named sys2 = ODESystem(eqs);

prob2 = ODEProblem(sys2, [], tspan);

@parameters t β=0.1 c=10.0 γ=0.25 ρ=0.1 h=0.1 d=0.1 r=0.1 v=0.1
@parameters t β2=0.1 c2=10.0 ρ2=0.1 h2=0.1 d2=0.1 r2=0.1
@variables S(t)=990.0 I(t)=10.0 R(t)=0.0 H(t)=0.0 D(t)=0.0
@variables Sv(t)=0.0 Iv(t)=0.0 Rv(t)=0.0 Hv(t)=0.0 Dv(t)=0.0
@variables I_total(t)

∂ = Differential(t)
N = S + I + R + H + D + Sv + Iv + Rv + Hv + Dv # This is recognized as a derived variable
eqs = [∂(S) ~ -β * c * I_total / N * S - v * Sv,
    ∂(I) ~ β * c * I_total / N * S - γ * I - h * I - ρ * I,
    ∂(R) ~ γ * I + r * H,
    ∂(H) ~ h * I - r * H - d * H,
    ∂(D) ~ ρ * I + d * H,
    ∂(Sv) ~ -β2 * c2 * I_total / N * Sv + v * Sv,
    ∂(Iv) ~ β2 * c2 * I_total / N * Sv - γ * Iv - h2 * Iv - ρ2 * Iv,
    ∂(Rv) ~ γ * I + r2 * H,
    ∂(Hv) ~ h2 * I - r2 * H - d2 * H,
    ∂(Dv) ~ ρ2 * I + d2 * H,
    I_total ~ I + Iv,
];

@named sys3 = ODESystem(eqs)
sys3 = structural_simplify(sys3)
prob3 = ODEProblem(sys3, [], tspan);

tsave = 0.0:last(tspan)
sol3 = solve(prob3, saveat = tsave);
data = [I => sol3[I], R => sol3[R]]

ensemble_priors = (
  [β=>LogNormal(0.0), γ=>LogNormal(0.0)],
  [β=>LogNormal(0.0), γ=>LogNormal(0.0), ρ=>LogNormal(0.0)],
  # NOTE: γ is listed twice below; the last entry was presumably meant to be a different parameter.
  [β=>LogNormal(0.0), γ=>LogNormal(0.0), ρ=>LogNormal(0.0), γ=>LogNormal(0.0)],
)

p_posterior = @time bayesian_datafit((prob,prob2,prob3), ensemble_priors, tsave, data)

The weights are currently constrained to sum to 1, but they are also allowed to be negative. For example, with two models you could end up with a solution like 2 .* sol1 .- sol2. We could constrain the weights to a simplex instead, but to start with I figured it would be interesting to see whether the fit favors solutions that subtract one result from another.

My concern is that this allows too much freedom and invites creative overfitting. I wouldn't want to see solutions like

124242.35 .* sol1 .- 84232.298 .* sol2 .- 40009.052 .* sol3
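
To make the distinction concrete, here is a minimal sketch of the two weight parameterizations discussed above, using only Base Julia. The names `sumto1_weights`, `simplex_weights`, and `ensemble` are hypothetical helpers written for illustration, not functions from EasyModelAnalysis, and the softmax is only one possible way to map onto the simplex; it is not necessarily what this PR would use. The `sols` vectors are toy stand-ins for sol1, sol2, and sol3.

```julia
# Hypothetical helpers (not from the PR) contrasting two weight parameterizations.

# Sum-to-one, sign-unconstrained: the first K-1 weights are free and the
# last one is whatever makes the total equal 1.
sumto1_weights(free) = vcat(free, 1 - sum(free))

# Simplex: a softmax maps any real vector to nonnegative weights summing to 1.
function simplex_weights(z)
    e = exp.(z .- maximum(z))  # subtract the maximum for numerical stability
    return e ./ sum(e)
end

# Weighted ensemble prediction: `sols` is a vector of per-model solution vectors.
ensemble(ws, sols) = sum(w .* s for (w, s) in zip(ws, sols))

# Toy stand-ins for sol1, sol2, sol3 at three save points (made-up numbers).
sols = [[1.0, 2.0, 3.0], [1.5, 2.5, 3.5], [0.0, 1.0, 2.0]]

# The extreme weights from the example above; the third weight is implied.
w_free = sumto1_weights([124242.35, -84232.298])
w_simplex = simplex_weights([2.0, 0.0, -1.0])

println("sum-to-one weights: ", w_free, " (sum = ", sum(w_free), ")")
println("simplex weights:    ", w_simplex, " (sum = ", sum(w_simplex), ")")
println("simplex ensemble:   ", ensemble(w_simplex, sols))
```

Both weight vectors sum to 1, but only the simplex version keeps the ensemble a convex combination of the individual solutions, which rules out the large cancelling weights shown above at the cost of never allowing a subtracted result.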

codecov[bot] commented 1 year ago

Codecov Report

Merging #191 (f0d7d4d) into main (94e9907) will decrease coverage by 6.96%. The diff coverage is 21.90%.

@@            Coverage Diff             @@
##             main     #191      +/-   ##
==========================================
- Coverage   73.36%   66.40%   -6.96%     
==========================================
  Files           7        7              
  Lines         428      512      +84     
==========================================
+ Hits          314      340      +26     
- Misses        114      172      +58     

Impacted Files             Coverage Δ
src/EasyModelAnalysis.jl   100.00% <ø> (ø)
src/datafit.jl             48.07% <21.90%> (-11.61%) :arrow_down:
