SciML / JumpProcesses.jl

Build and simulate jump equations like Gillespie simulations and jump diffusions with constant and state-dependent rates and mix with differential equations and scientific machine learning (SciML)
https://docs.sciml.ai/JumpProcesses/stable/

VariableRateJump - Jumps don't match expected distribution for birth-death process #320

Closed freddie090 closed 1 month ago

freddie090 commented 1 year ago

I saw that a mixed birth-death ODE-jump process I was using wasn't behaving as expected.

After some investigation, it appears that when I switch the jumps from ConstantRateJump to VariableRateJump, the simulated birth-death model no longer matches the theoretical expectation for a birth-death process.

The issue is only noticeable when the initial population size is small (e.g. n0 = 1.0 in the code below), so I'm assuming it has something to do with how the time between jumps is calculated when the gap between jumps is large.

Here is a simple toy birth-death jump process model I wrote to investigate the problem. The real problem is more complicated: the variable rates aren't necessary in this version, but they are in the full model. In the full model the rates also depend on part of the ODE solution, so the continuous problem has to be used.

Additionally, although this version could be modelled without VariableRateJump, I would expect the behaviour to match ConstantRateJump whenever the model goes through a period where the rates are constant between jumps, even if the rates become variable later on.

using JumpProcesses, OrdinaryDiffEq

function grow_fxn_sjm(n0::Float64, b::Float64, d::Float64, tmax::Float64)

    u0 = [n0]
    tspan = (0.0, tmax)
    p = [b, d]

    # ODE part: this toy model has no continuous dynamics, so the derivative is zero.
    function ode_fxn(du, u, p, t)
        du .= 0
        nothing
    end

    # Jump effects: a birth adds one individual, a death removes one.
    function birth!(integrator)
        integrator.u[1] += 1
        nothing
    end
    function death!(integrator)
        integrator.u[1] -= 1
        nothing
    end

    # Total birth and death rates: per-capita rate times current population size.
    b_rate(u, p, t) = u[1] * p[1]
    d_rate(u, p, t) = u[1] * p[2]

    b_jump = VariableRateJump(b_rate, birth!)
    d_jump = VariableRateJump(d_rate, death!)

    ode_prob = ODEProblem(ode_fxn, u0, tspan, p)

    sjm_prob = JumpProblem(ode_prob, Direct(), b_jump, d_jump)

    sol = solve(sjm_prob, Tsit5())

    return sol

end

If I run this model many times with a starting population (n0) of 1.0, I can see that the model under-estimates the expected population size at time t, given some birth and death rates b and d. For the simple birth-death process, the expected population size should simply be n(t) = n0 * e^((b - d)*t). If I change b_jump and d_jump to ConstantRateJumps, the model does reach the expected population size n(t) when averaged over many runs.
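For reference, the expectation quoted above follows from the mean equation of a linear birth-death process. The symbol m(t) for the mean population size is introduced here only for illustration:

```latex
\begin{aligned}
m(t) &:= \mathbb{E}[n(t)], \\
\frac{dm}{dt} &= b\,m(t) - d\,m(t) = (b - d)\,m(t), \qquad m(0) = n_0, \\
m(t) &= n_0\, e^{(b - d)\,t}.
\end{aligned}
```

As an example, n0 = 1, b = 2, d = 1 and t = 4 give m(4) = e^4 ≈ 54.6.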

isaacsas commented 1 year ago

What parameters are you choosing, and how many simulations are you averaging over? For a birth-dominated system this model grows exponentially, so I'd expect it to have issues when the population gets big (though from what you say the issue is when the population is small). Note that the method you are using is not "exact", in the sense that Tsit5() is used to integrate the intensities in time, and a ContinuousCallback is used to root-find when these integrals hit a random number (i.e. to determine the next time of a given jump). The accuracy of the time integration in particular depends on the error tolerances passed to the ODE solver.

So there are a couple things you can try. Can you let me know if either of the following helps:

  1. Use Coevolve with SSAStepper instead of Direct with an ODE solver. (Note, this does require you to provide an upper bound function on the rate as described in the docs, but it should generally be a much faster method for pure variable rate problems, and should be "exact" as it avoids numerical integration and instead uses rejection sampling.)
  2. In your current code, try decreasing the error tolerances in the call to solve (i.e. abstol and reltol, see https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/#solver_options).
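A minimal sketch of suggestion 1 for the toy model, assuming the bounded-rate keyword arguments (`urate`, `lrate`, `rateinterval`) of `VariableRateJump` and the `dep_graph` keyword of `JumpProblem` as described in the JumpProcesses docs. It has not been checked against a specific release. Because the rates here are constant between jumps, the rate itself is used as both bound and the bound is treated as valid indefinitely; that choice is an assumption specific to this toy model.

```julia
using JumpProcesses

# Jump effects, as in the toy model above.
birth!(integrator) = (integrator.u[1] += 1; nothing)
death!(integrator) = (integrator.u[1] -= 1; nothing)

b_rate(u, p, t) = u[1] * p[1]
d_rate(u, p, t) = u[1] * p[2]

# Assumption: the bounds stay valid until the next jump, so the
# interval of validity is unbounded (typemax(t) is Inf for Float64).
forever(u, p, t) = typemax(t)

b_jump = VariableRateJump(b_rate, birth!;
                          urate = b_rate, lrate = b_rate, rateinterval = forever)
d_jump = VariableRateJump(d_rate, death!;
                          urate = d_rate, lrate = d_rate, rateinterval = forever)

# Pure jump problem: integer state, no ODE solver involved.
dprob = DiscreteProblem([1], (0.0, 4.0), [2.0, 1.0])

# Either jump changes u[1], so each jump affects the rates of both jumps.
jprob = JumpProblem(dprob, Coevolve(), b_jump, d_jump;
                    dep_graph = [[1, 2], [1, 2]])

sol = solve(jprob, SSAStepper())
```

Suggestion 2 amounts to passing the tolerances as keyword arguments to the existing call, e.g. `solve(sjm_prob, Tsit5(); abstol = 1e-12, reltol = 1e-12)`.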

That said, I tried 2., reducing the error tolerances to abstol = 1e-12 and reltol = 1e-12, and it was still underestimating when b = 2, d = 1, n0 = 1 and tspan = (0.0, 4.0), which is a case where the expected final population is < 100. (Increasing the number of samples reduced the size of the underestimate, but it was still noticeable even at 160,000 sample paths.) On the other hand, if I pick d = 2 and b = 1 and start with a larger population I don't see any issues. So I think the problem has to do with the exponential growth when birth dominates.
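One way to get an independent reference for these parameters is a hand-written Gillespie simulation that uses only the Julia standard library. The function name `birth_death_final`, the seed and the sample count below are illustrative choices, not part of JumpProcesses or of the experiment described above.

```julia
using Random, Statistics

# Exact (Gillespie) simulation of a linear birth-death process.
# Returns the population size at time tmax.
function birth_death_final(n0, b, d, tmax; rng = Random.default_rng())
    n, t = n0, 0.0
    while n > 0
        total = (b + d) * n              # total jump rate in the current state
        t += randexp(rng) / total        # exponential waiting time to the next jump
        t > tmax && break                # next jump falls after the end time
        n += rand(rng) < b / (b + d) ? 1 : -1   # birth with probability b / (b + d)
    end
    return n
end

rng = MersenneTwister(1234)
samples = [birth_death_final(1, 2.0, 1.0, 4.0; rng) for _ in 1:20_000]

estimate = mean(samples)                 # Monte Carlo estimate of the mean
expected = exp((2.0 - 1.0) * 4.0)        # n0 * e^((b - d) * t) with n0 = 1
```

Comparing `estimate` with `expected` gives a baseline for how much sampling noise to expect before attributing a gap to the solver.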

@ChrisRackauckas any suggestions on how to handle this via the ODE solvers? It didn't look like one could really reduce the tolerances in the jump callbacks any further, as the defaults are already pretty small...

freddie090 commented 1 year ago

Hi @isaacsas - thanks for looking at this.

So re 1., am I right in thinking that because in my full model the rates can depend on an ODE solution, using SSAStepper wouldn't be a permanent solution? (As I'd need it to be part of a continuous problem?)

And yes, sorry I wasn't clear which parameters I was looking at - I'm interested in a growing population (so b > d) - there is a Callback in the full model to ensure the population doesn't reach stupidly large numbers.

Because the ConstantRateJump seems to work, and the VariableRateJump under-estimates the population size at small population sizes when b > d, I'm assuming that the way VariableRateJump calculates the time to the next event over-estimates that time. Because b > d, this shows up as a smaller population size (as opposed to, say, the simulation producing too many death events). Would that be consistent with what VariableRateJump is doing under the hood?

isaacsas commented 1 year ago

You can see what is going on here:

https://github.com/SciML/JumpProcesses.jl/blob/c8ec2ba109a85db569540d050486e18035f0f334/src/problem.jl#L265

The ODE state variables are internally augmented with the integrated intensities, and the ODE derivative function is augmented to return the current value of the intensities/rates. ContinuousCallbacks are created to determine when the integrated intensity equals a random number, which sets the next time at which that jump fires. But I'm not sure where the loss of accuracy is for your problem.
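The idea described above can be sketched in plain Julia. This is an illustration of the integrated-intensity approach only, not the code JumpProcesses runs: the helper `next_jump_time`, the fixed step `dt` and the trapezoid rule are assumptions made for the example, whereas the library relies on the ODE solver and a ContinuousCallback.

```julia
using Random

# Hypothetical helper: time of the next jump for a time-dependent intensity
# `rate(t)`. The integrated intensity is accumulated with a trapezoid rule
# until it reaches an Exp(1)-distributed threshold, i.e. -log(U) for uniform U.
function next_jump_time(rate, t0, tmax; dt = 1e-3, rng = Random.default_rng())
    threshold = randexp(rng)
    t, acc = t0, 0.0
    while t < tmax
        step = min(dt, tmax - t)
        inc = 0.5 * (rate(t) + rate(t + step)) * step
        if acc + inc >= threshold
            # Locate the crossing inside the step by linear interpolation.
            return t + step * (threshold - acc) / inc
        end
        acc += inc
        t += step
    end
    return nothing   # no jump before tmax
end

# For a constant intensity the waiting time reduces to threshold / rate.
t_jump = next_jump_time(t -> 2.0, 0.0, 100.0; rng = MersenneTwister(1))
```

Any error in accumulating the integral or in locating the crossing shifts the jump time, which is why the solver tolerances matter for this approach.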

isaacsas commented 1 month ago

Closed by https://github.com/SciML/JumpProcesses.jl/issues/320