SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/

Parameter estimation with multiple sets of initial conditions #234

Closed: mcfefa closed this issue 6 years ago

mcfefa commented 6 years ago

Hi @ChrisRackauckas, I am trying to use the parameter estimation examples at http://docs.juliadiffeq.org/latest/analysis/parameter_estimation.html to find parameters for a system of ODEs in a biology problem. I am able to do the fit for a single "run" or "trial" of my experiments, but I want to fit one set of parameters across multiple runs.

That is, each row of my data corresponds to a single set of initial conditions, but all rows share the same ODE parameters. I can find the parameter set for one row of my data, but I'm having trouble figuring out how to fit one set of parameters across multiple rows/experimental conditions. Is there an example of how to do this somewhere?

Thanks! Meghan

ChrisRackauckas commented 6 years ago

Let's do this step by step.

First you want to create a problem which solves multiple problems at the same time: the MonteCarloProblem. When the parameter estimation tools say they take any DEProblem, they really mean ANY DEProblem!

So let's set up a Monte Carlo problem that solves the ODE with 10 different initial conditions. We define our full ODE as follows (I'm just going to use Lotka-Volterra since I'm a biologist too!):

using DifferentialEquations
# Lotka-Volterra in in-place form; p holds the two parameters we will estimate
function pf_func(t,u,p,du)
  du[1] = p[1] * u[1] - p[2] * u[1]*u[2]
  du[2] = -3 * u[2] + u[1]*u[2]
end
pf = ParameterizedFunction(pf_func,[1.5,1.0])  # "true" parameters used to generate the data
prob = ODEProblem(pf,[1.0,1.0],(0.0,10.0))

Now for a MonteCarloProblem we have to take this problem and tell it what to do N times via the prob_func. So let's generate N=10 different initial conditions, and tell it to run the same problem but with these 10 different initial conditions each time:
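
# stand-in values; in practice use the measured initial conditions from your runs
N = 10
initial_conditions = [[1.0,1.0] + 0.3*rand(2) for i in 1:N]

Then the prob_func swaps the ith of these initial conditions into the same problem: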

function prob_func(prob,i,repeat)
  ODEProblem(prob.f,initial_conditions[i],prob.tspan)
end
monte_prob = MonteCarloProblem(prob,prob_func=prob_func)

We can check this does what we want by solving it:

sim = solve(monte_prob,Tsit5(),num_monte=N)
using Plots; plotly()
plot(sim)

(Plot: the 10 simulated Lotka-Volterra trajectories.)

num_monte=N means "run N times", and each time it runs the problem returned by the prob_func, which is always the same problem but with the ith initial condition.
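
A quick sanity check that the kth trajectory really starts from the kth initial condition (the first saved state is the initial condition by default):

sim[3].u[1] == initial_conditions[3]  # true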

Now let's generate a dataset from that. Let's save data points every 0.1 time units using saveat, and then convert the solution into an array:

data_times = 0.0:0.1:10.0
sim = solve(monte_prob,Tsit5(),num_monte=N,saveat=data_times)
data = convert(Array,sim)

Here, data[i,j,k] is the same as sim[i,j,k] which is the same as sim[k][i,j] (where sim[k] is the kth solution). So data[i,j,k] is the jth timepoint of the ith variable in the kth trajectory.
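
As a quick check of that indexing (a throwaway line, not needed for the fit):

data[1,5,3] == sim[3][1,5]  # 1st variable, 5th saved timepoint, 3rd trajectory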

Now let's build a loss function. A loss function is some loss(sol) that spits out a scalar measuring how far from optimal we are. In the documentation I show that we normally do loss = L2Loss(t,data), but we can bootstrap off of this. Instead, let's build an array of N loss functions, each one with the correct piece of data:

losses = [L2Loss(data_times,data[:,:,i]) for i in 1:N]

Now losses[i] is a function which computes the loss of a solution against the data of the ith trajectory. To build our full loss function, we just sum the per-trajectory losses:

loss(sim) = sum(losses[i](sim[i]) for i in 1:N)

As a double check, make sure that loss(sim) outputs zero, since we generated the data from this exact simulation:
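
loss(sim)  # should be ≈ 0.0

Now let's re-run the Monte Carlo simulation with a different set of parameters: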

pf = ParameterizedFunction(pf_func,[1.2,0.8])  # different parameters than the ones that generated the data
prob = ODEProblem(pf,[1.0,1.0],(0.0,10.0))
function prob_func(prob,i,repeat)
  ODEProblem(prob.f,initial_conditions[i],prob.tspan)
end
monte_prob = MonteCarloProblem(prob,prob_func=prob_func)
sim = solve(monte_prob,Tsit5(),num_monte=N,saveat=data_times)
loss(sim)

and get a non-zero loss. So we now have our problem, our data, and our loss function... we have what we need.

Now put this into build_loss_objective:

obj = build_loss_objective(monte_prob,Tsit5(),loss,num_monte=N,
                           saveat=data_times)

Notice that I added the kwargs for solve into this call. They get passed to an internal solve command, so the loss is computed on all N trajectories at data_times.

Now we can take this objective function to any optimization package. I like to do quick things in Optim.jl. Here, since the Lotka-Volterra equations require positive parameters, I use Fminbox to make sure the parameters stay positive. The optimization command is:

using Optim
lower = zeros(2)
upper = fill(3.0,2)
result = optimize(obj, [1.3,0.9], lower, upper, Fminbox{BFGS}())

I start the optimization with [1.3,0.9], and Optim spits out that the true parameters are:

Results of Optimization Algorithm
 * Algorithm: Fminbox with BFGS
 * Starting Point: [1.3,0.9]
 * Minimizer: [1.5000000004217302,0.9999999987018824]
 * Minimum: 1.482579e-13
 * Iterations: 4
 * Convergence: true
   * |x - x'| < 1.0e-32: true 
     |x - x'| = 0.00e+00 
   * |f(x) - f(x')| / |f(x)| < 1.0e-32: true
     |f(x) - f(x')| / |f(x)| = 0.00e+00 
   * |g(x)| < 1.0e-08: false 
     |g(x)| = 1.20e-03 
   * Stopped by an increasing objective: true
   * Reached Maximum Number of Iterations: false
 * Objective Calls: 108
 * Gradient Calls: 108

[1.5,0.99999]... close enough. So Optim found the true parameters.

I would run a test on synthetic data for your problem before using it on real data. Maybe play around with different optimization packages (there's a sketch of one alternative below), or add regularization. You may also want to tighten the tolerances of the ODE solvers via

obj = build_loss_objective(monte_prob,Tsit5(),loss,num_monte=N,
                           abstol=1e-8,reltol=1e-8,
                           saveat=data_times)

if you suspect solver error is the problem. However, if you're having trouble, it's most likely not the ODE solver tolerance; it's mostly that parameter inference is just a very hard optimization problem.
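
Since obj is just a function of the parameter vector, it's also easy to try a global optimizer instead. A quick sketch with BlackBoxOptim.jl (the search range here is just an illustrative guess):

using BlackBoxOptim
bound = Tuple{Float64,Float64}[(0.0,3.0),(0.0,3.0)]
result = bboptimize(obj; SearchRange = bound, MaxSteps = 11e3)
best_candidate(result)  # estimated parameters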

Let me know if this helps. If it did, I can add a notebook on this to DiffEqTutorials.jl (or if you're up to it, it would make a great PR!). Let me know if there's anything that needs more explanation.

mcfefa commented 6 years ago

Thanks @ChrisRackauckas, this was so helpful! I am getting results now, but you're right that parameter estimation is hard.

I opened PR https://github.com/JuliaDiffEq/DiffEqTutorials.jl/pull/16 with this example.

ChrisRackauckas commented 6 years ago

Great. Let me know if you need anything else. Over the next year and next summer we'll be building a lot of new tools using expectation maximization and Bayesian methods at DiffEqBayes.jl. If you can't find an optimizer that works, hopefully we'll have something else sooner rather than later. Take care.