SciML / JumpProcesses.jl

Build and simulate jump equations, such as Gillespie simulations and jump diffusions, with constant and state-dependent rates, and mix them with differential equations and scientific machine learning (SciML)
https://docs.sciml.ai/JumpProcesses/stable/

Proposal: Abstraction for joint distribution sampling #19

Closed alanderos91 closed 6 years ago

alanderos91 commented 6 years ago

Disclaimer: I cannot speak to whether this approach is valid/generalizes to more general jumps. The discussion that follows is predicated on pure Markov jump processes (and therefore applies to ConstantRateJump).

Gillespie's paper identifies a so-called reaction probability density that serves as the basis of his simulation algorithm (Equation 18). The main work of any Gillespie algorithm is therefore sampling a random pair (τ, μ), where τ is the time to the next jump and μ is the index of the jump that fires.
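For reference, that density (Gillespie's Eq. 18) can be written in the notation used here as

```latex
p(\tau, \mu \mid x, t)\, d\tau \;=\; a_\mu(x)\, e^{-a_0(x)\,\tau}\, d\tau,
\qquad a_0(x) = \sum_{j=1}^{M} a_j(x),
```

where the a_j(x) are the jump rates (propensities) in state x, τ ≥ 0, and μ ∈ {1, …, M}.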

I am probably not offering anything new here, but I want to highlight the three main computational steps in Gillespie-like algorithms:

  1. Sample from the joint probability distribution of (τ, μ).
  2. Update the Markov chain state.
  3. Compute the new joint probability distribution of (τ, μ).

The probability distribution changes because the jump rates change. Perhaps this is all obvious, but Step 3 is obscured in the literature and in implementations as a consequence of how the algorithms are described in articles. All the Gillespie-like algorithms can be summarized by how they implement Step 3, and so I believe what is needed is a flexible abstraction around this task.
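To make the three steps concrete, here is a minimal, self-contained sketch of the Direct-method loop. The function name and the toy birth-death model are illustrative only, not the package API:

```julia
using Random

# Sketch of the three steps for Gillespie's Direct method (illustrative names).
function direct_gillespie(rates, affects!, u0, tspan; rng = MersenneTwister(0))
    u, t, tend = copy(u0), first(tspan), last(tspan)
    while true
        # Step 3 (recompute the distribution): evaluate jump rates in current state
        a  = [r(u, t) for r in rates]
        a0 = sum(a)
        a0 <= 0 && break
        # Step 1 (sample the pair): τ ~ Exp(a0), P(μ = j) = a_j / a0
        τ = randexp(rng) / a0
        μ = searchsortedfirst(cumsum(a), rand(rng) * a0)
        t += τ
        t > tend && break
        # Step 2 (update the chain state)
        affects![μ](u)
    end
    u, t
end

# Toy birth-death model: X -> X+1 at rate 2, X -> X-1 at rate 0.1*X
rates    = ((u, t) -> 2.0, (u, t) -> 0.1 * u[1])
affects! = (u -> u[1] += 1, u -> u[1] -= 1)
u, t = direct_gillespie(rates, affects!, [10], (0.0, 100.0))
```

All Gillespie-like SSAs share Steps 1 and 2; what varies is how cheaply Step 3 (and the sampling it feeds) can be done.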

The aggregator system with the new SSAStepper seems to partially address this problem. I am proposing something of the following form (apologies for the bad names):

  1. Have the aggregator track both the next jump index and the next jump time. The aggregator already "encodes" the probability distribution with cur_rates, sum_rate, and rates. All we need is the index.

    mutable struct DirectJumpAggregation{T,F1,F2} <: AbstractJumpAggregator
        next_jump::Int    # the index of the next jump
        next_jump_time::T # was next_jump! sorry!
        end_time::T
        cur_rates::Vector{T}
        sum_rate::T
        rates::F1
        affects!::F2
        save_positions::Tuple{Bool,Bool}
    end
  2. Provide a function that samples from the distribution. Assume the aggregator already determined the next random pair:

    # not really sampling...
    sample_distribution(p::DirectJumpAggregation) = p.next_jump_time, p.next_jump
  3. Provide a function that determines the next random pair. This is what distinguishes the SSAs, so new aggregators should implement it. For the existing aggregator:

    # maybe called `update_pair`?
    function update_distribution!(p, integrator)
        # recompute the (cumulative) rates first, then sample against them
        p.sum_rate, ttnj = time_to_next_jump(integrator.u,integrator.p,integrator.t,p.rates,p.cur_rates)
        rng_val = rand() * p.sum_rate
        p.next_jump = searchsortedfirst(p.cur_rates,rng_val)

        p.next_jump_time = integrator.t + ttnj
        if p.next_jump_time < p.end_time
            add_tstop!(integrator,p.next_jump_time)
        end
    end
  4. Rewrite the current affect into a general procedure:

    function (p::DirectJumpAggregation)(integrator) # affect!
        # 1. obtain the next random pair
        τ, μ = sample_distribution(p)

        # 2. update simulation by applying the affect for the selected jump
        @inbounds p.affects![μ](integrator)

        # 3. update the jump rates and determine the next random pair
        update_distribution!(p, integrator)

        nothing
    end

The benefits of this "abstraction" are:

  1. New algorithms are new aggregators that implement an update_distribution! method. That's it.

  2. The sample_distribution function does not care whether the algorithm couples or decouples the random pair. This is all handled by the update procedure, where it actually matters.

  3. It may be possible to provide some default algorithms for some of the subtasks in "updating the distribution". For example, most papers (and implementations) determine the next jump index using a linear search rather than bisection. This requires that cur_rates is a vector of the jump rates rather than a cumulative sum, but that's an implementation detail specific to a particular SSA. The point is, new SSAs can be built up by combining the smaller algorithms for specific tasks.

  4. If (3) is possible, then we get new algorithms for free!
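To illustrate point (3): given the same uniform draw x in [0, a0), the index μ can be drawn either by a linear search over the raw rates or by bisection over their cumulative sum. A sketch with hypothetical helper names:

```julia
# Two interchangeable ways to draw the jump index μ with probability a_μ / a0
# (illustrative helpers, not package API).

# Linear search over the raw rates, as written in most published SSAs.
function sample_index_linear(rates::Vector{Float64}, x::Float64)
    s = 0.0
    for (j, a) in enumerate(rates)
        s += a
        x < s && return j
    end
    return length(rates)
end

# Bisection over the cumulative sum of the rates; requires cur_rates to be
# stored as a cumsum. (The two agree up to a measure-zero boundary case.)
sample_index_bisect(cumrates::Vector{Float64}, x::Float64) =
    searchsortedfirst(cumrates, x)
```

Which representation of cur_rates an SSA stores then determines which helper it can use, which is exactly the kind of per-algorithm detail point (3) is about.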

Relevant: #16

alanderos91 commented 6 years ago

Some details omitted in the first post:

Markov chain theory allows one to cook up all sorts of recipes for sampling (τ, μ):

ChrisRackauckas commented 6 years ago

Have the aggregator track both the next jump index and the next jump time. The aggregator already "encodes" the probability distribution with cur_rates, sum_rate, and rates. All we need is the index.

Just make a new aggregator that has the index tracked. There is no "the aggregator". Each aggregator is simply one implementation of one method, and the whole point of the interface is that there can be choices other than Direct() (though right now there are not).

But yeah,

New algorithms are new aggregators that implement an update_distribution! method. That's it.

Sure, a set of helper functions that make it easy to implement new aggregators (and encourage code re-use) is highly appreciated! It sounds like you have a plan for that and I think that it sounds great.

isaacsas commented 6 years ago

This interface would work nicely, I think, except in the aggregate routine. Currently the Direct solver is also initialized there (i.e. a next jump time is calculated), but no integrator is provided at that point. You could either change to

update_distribution!(p, u, params, t)

or maybe the calls to calculate the next reaction time are not needed in aggregate and could be removed? @ChrisRackauckas can you comment on the necessity of the latter? I don't think SSAStepper needs them, but I'm not sure about FunctionMaps or other contexts.

ChrisRackauckas commented 6 years ago

or maybe the calls to calculate the next reaction time are not needed in aggregate and could be removed?

Oh, they should be removed. They are re-initialized with every solve anyway

https://github.com/JuliaDiffEq/DiffEqJump.jl/blob/master/src/aggregators/direct.jl#L28

because otherwise, if you re-solve a jump problem, you get screwy results. Those values are always ignored, so that computation should be removed.

alanderos91 commented 6 years ago

Just make a new aggregator that has the index tracked. There is no "the aggregator". Each aggregator is simply one implementation of one method, and the whole point of the interface is that there can be choices other than Direct() (though right now there are not).

Agreed. My only concern is how this plays out with specialized mass action jumps @isaacsas is implementing. It's certainly nice to have Direct(), DirectWithFWrappers(), and so on. These aggregator algorithms should generate aggregators where the F1 and F2 type parameters differ.

Can we require new aggregators to provide certain fields (e.g. next_jump and time_to_next_jump)?

ChrisRackauckas commented 6 years ago

Can we require new aggregators to provide certain fields (e.g. next_jump and time_to_next_jump)?

Yeah, we can make the interface more constrained. I don't know enough about all of the various SSAs to make the determination on how flexible of an interface it needs to be, so my interface was just that the aggregator should have one function aggregate(aggregator::Direct,u,p,t,end_time,constant_jumps,save_positions) that returns a callback. If there's something that's easier to work with but still allows all of the fancy SSA designs then we should do that, or at least make some behind-the-scenes helper functions for implementing aggregators.

isaacsas commented 6 years ago

One possibility would be to have a new abstract type

SSAJumpAggregator <: AbstractJumpAggregator

and require those fields for it. Then one could write some of the interface to work over all SSAJumpAggregators (i.e. the affect and condition functors). At the end of the day this seems like it's going to require users to implement the JumpAggregation type, the initialization functor, aggregate, and an update_samples-type routine. The rest could be generic, I think. It's not substantially reducing the number of routines that must be filled in, but it perhaps makes clearer what those routines should do.

isaacsas commented 6 years ago

Are there any use cases for jump models where more than one jump could be executed at a time? If so, it may be better to specialize for SSA models (which only have one jump per event) rather than force that assumption on all aggregators.

alanderos91 commented 6 years ago

Are there any use cases for jump models where more than one jump could be execute at a time?

In the sense that the jumps happen at the same time t: I don't think so. I think such events only occur on sets of zero probability for Poisson processes. In the sense that the jumps happen at the same simulation step: only tau-leaping comes to mind.

I agree that we should specialize for SSA.
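For concreteness, a sketch of what "several jumps per simulation step" looks like in a tau-leap step: each channel fires a Poisson-distributed number of times over the leap interval. The names and the inline Poisson sampler are illustrative only, not the package:

```julia
using Random

# Knuth's multiplication-based Poisson sampler (fine for small means;
# real codes would use a library sampler).
function rand_poisson(rng, λ)
    L, k, p = exp(-λ), 0, 1.0
    while true
        p *= rand(rng)
        p < L && return k
        k += 1
    end
end

# One tau-leap step: channel j fires N_j ~ Poisson(a_j(u,t) * τ) times,
# so multiple jumps (across and within channels) happen in one step.
function tau_leap_step!(u, t, τ, rates, stoich, rng)
    for (j, r) in enumerate(rates)
        n = rand_poisson(rng, r(u, t) * τ)
        u .+= n .* stoich[j]    # apply the j-th stoichiometry n times at once
    end
    u
end
```

This is exactly the regime where the one-jump-per-event SSA abstraction above would not apply, supporting the idea of specializing it.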

ChrisRackauckas commented 6 years ago

One possibility would be to have a new abstract type

I think this makes sense. Remember, this isn't really a user-level abstraction. It's a dev abstraction since most users won't implement an SSA. I think an abstract type with a simplified implementation strategy, while leaving the actual interface just aggregate -> callback is flexible enough for research but gives us an easy way to implement all of the normal SSAs (this is pretty much how the DE solvers are done. Technically anyone can add an algorithm, but in practice most are from the *DiffEq packages which have a lot more structure). I'll let you guys design that subtype since you guys know the other SSAs a lot better than me.

isaacsas commented 6 years ago

One thing that would be nice to get into the interface as long as it is being updated is the ability to specify the random number generator. For SSAs, libraries like Random123 can be much better since they allow the generation of multiple statistically strong RN sequences. This is useful if one is running many copies of an SSA simulation in parallel. Could we make the RNG an optional parameter to SSAJumpAggregators?

ChrisRackauckas commented 6 years ago

Yes. It should just use RandomNumbers.jl like DiffEqNoiseProcess.jl does.

isaacsas commented 6 years ago

OK, I'll take a look at how the interface works there then.

alanderos91 commented 6 years ago

Awesome. I think I have a clear enough idea to put together a PR. I'll touch base again once I have something running to handle optional RNGs.

isaacsas commented 6 years ago

Cool, I'm looking forward to it! Once the interface is straightened out I can put through the mass action reaction jump type and start submitting SSAs that include it...

One possible way to give the user RNG selection freedom would be to make it an optional named argument to the JumpProblem constructor, and then pass it into aggregate from the constructor. That would be a pretty minimal way to get it into each JumpAggregation SSA.

alanderos91 commented 6 years ago

I've run into a few problems that require design choices:

  1. It's not possible to define functors on abstract types. Namely, (p::SSAJumpAggregator)(integrator) = ... throws an error even though retrieve_jump(p::SSAJumpAggregator) = ... is allowed. This is an open issue. It's unfortunate because this is precisely what I wanted to target. Worst case everyone just copy-pastes the function body of one SSAJumpAggregator until we find a suitable workaround/the issue is resolved?

  2. I'm including some helper functions for (i) updating rates, (ii) sampling jump times, and (iii) sampling jump indices. Generating new jumps looks like this:

    function generate_jump!(p::DirectJumpAggregation,u,params,t)
        # update the jump rates
        sum_rate = cur_rates_as_cumsum(u,params,t,p.rates,p.cur_rates)
        # determine next jump index (scale the uniform draw by the total rate)
        i = randidx_bisection(p.cur_rates, rand(p.rng) * sum_rate)
        # determine next jump time
        ttnj = randexp_ziggurat(p.rng,sum_rate)
        # mutate fields
        p.sum_rate = sum_rate
        p.next_jump = i
        p.next_jump_time = t + ttnj
        nothing
    end

    The issue is whether the helper functions should depend on an aggregator or on fields/values generated outside the function call. As an example:

        # aggregator independent
        @inline randexp_ziggurat(rng,sum_rate) = randexp(rng) / sum_rate

        # aggregator dependent
        @inline randexp_ziggurat(p::SSAJumpAggregator) = randexp(p.rng) / p.sum_rate

    This example is trivial; I'm just wrapping a function call. Algorithms that couple the time to the next jump and the random index are a bit more subtle:

        function randjump_firstjump_ziggurat(rng,cur_rates)
            ttnj = typemax(eltype(cur_rates))
            i = 0
            for j in eachindex(cur_rates)
                dt = randexp(rng) / cur_rates[j]
                if dt < ttnj
                    ttnj = dt
                    i = j
                end
            end
            return ttnj,i
        end
The Next Reaction Method introduces additional data structures. I think these helper functions are only helpful if they relieve some of the burden of figuring out what information to plug in. Any guidance on this matter is much appreciated.
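To sketch the kind of extra structure the Next Reaction Method needs: it keeps an absolute putative firing time per channel and, after a jump, redraws only the channels a dependency graph marks as affected. All names here are hypothetical, and a real implementation would use an indexed priority queue instead of a findmin scan:

```julia
using Random

# One absolute next-firing time per jump channel (zero rates give Inf,
# i.e. "never fires").
mutable struct NRMTimes
    times::Vector{Float64}
end

init_times(rates, u, params, t, rng) =
    NRMTimes([t + randexp(rng) / r(u, params, t) for r in rates])

# After firing channel μ, only the channels listed in depgraph[μ] get new
# times; the rest keep their old absolute times (the NRM reuse idea).
function update_times!(nt::NRMTimes, μ, depgraph, rates, u, params, t, rng)
    for j in depgraph[μ]
        nt.times[j] = t + randexp(rng) / rates[j](u, params, t)
    end
    findmin(nt.times)  # (next_jump_time, next_jump)
end
```

A helper API would have to decide whether structures like `NRMTimes` and the dependency graph live in the aggregator or get passed in, which is the same question as above.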

isaacsas commented 6 years ago

I'm fine with having a "template" type SSA body that gives recommended work flow.

For the rng: I would just like a way to have it passed down to the JumpAggregation (through aggregate?). I'm fine with calling randexp(p.rng) manually within each SSA instead of having a helper function. The helper routine just rescales to set the correct parameter of the exponential random variable, so if you keep it maybe call it randexp_scaled.

generate_jump! seems like a good way to update the state -- I've basically converted to using an equivalent update_samples routine.

Even barring the first issue you mention, it may be tough to do this in a fully generic way. If one is using constant_jumps, the method for executing a jump is through the appropriate affects! function, while with massaction_jumps there is a single global function to call with the reaction id (executerx!). Juggling both simultaneously adds a bit more work (and I think for more sophisticated SSAs I will only support mass action, since we'll need dependency graphs and such).

@alanderos91 Maybe take a look at what I've been playing with. I think I got the (function wrapped) Direct method version pretty clean at the moment, see Function Wrapper Direct Method. A version that supports constant jumps and mass action is: here (Note, for the latter I haven't actually tested a hybrid system with mass action and constant jumps yet -- I only just put it together earlier today.)

isaacsas commented 6 years ago

Looking at this more, I think what you pass into generate_jump! is the right set of stuff. You could then have an analogous execute_jump method. Then the only thing left is a function to handle setting the tstop for all SSA aggregators, and perhaps an initialization function. All the functors could then be copied/pasted from a common template. As you said, different SSAs will need different info (some need tree structures, some need table structures, ...). This can't be handled generically, unless we assume it is stored in the JumpAggregation struct and passed around.

ChrisRackauckas commented 6 years ago

Worst case everyone just copy-pastes the function body of one SSAJumpAggregator until we find a suitable workaround/the issue is resolved?

Yes, let's do this for now and I hope it just gets fixed.

alanderos91 commented 6 years ago

For the rng; I would just like a way to have the rng passed down to the jumpaggregation (through aggregate?).

Already done! JumpProblem takes rng as a named argument and propagates that to an aggregator as you suggested. It defaults to Base.GLOBAL_RNG and is stored as a field in an aggregator.

The helper routine is just rescaling to set the correct parameter for the exponential random variable, so if you keep it maybe you could just call it randexp_scaled.

I agree, and that's the main reason I'm partial to having the helper functions tied to an aggregator. I don't think I will keep these helper functions, but still believe there is something more to do here. I'll hold off on thinking about this so much for now; things will become clearer once we can look at a suite of SSAs in the package.

As you said different SSAs will need different info (some need tree structures, some need table structures,...). This can't be handled generically, unless we assume it is stored in the JumpAggregation struct and passed around.

I think the information should be stored by a JumpAggregation. Our generate_jump! and update_samples functions should act as the playground for SSA development. All the other required functions/functors are the means to this end. In this way a specific implementation can opt in to using certain features and data structures like dependency graphs, trees, and so on.

alanderos91 commented 6 years ago

Looking at your implementations I think we are in agreement. Here's a preview. Rather than providing a new JumpAggregation for implementations that use function wrappers or some other structure, might it be possible to generate these from the aggregator algorithms?

For example:

struct DirectFunWrap <: AbstractAggregatorAlgorithm end
struct DirectOtherWay <: AbstractAggregatorAlgorithm end

function aggregate(aggregator::DirectFunWrap,...)
  # build the rates and affects using get_jump_info_fwrappers
  DirectJumpAggregation(next_jump,next_jump_time,end_time,cur_rates,
    sum_rate,rates,affects!,save_positions,rng)
end

function aggregate(aggregator::DirectOtherWay,...)
  # build the rates and affects using some other way
  DirectJumpAggregation(next_jump,next_jump_time,end_time,cur_rates,
    sum_rate,rates,affects!,save_positions,rng)
end

Then we should just need update_samples to specialize on the type parameters for the rates and affects.

ChrisRackauckas commented 6 years ago

JumpProblem takes rng as a named argument and propagates that to an aggregator as you suggested. It defaults to Base.GLOBAL_RNG and is stored as a field in an aggregator.

I've found that it's pretty much universally better to default to building an Xorshifts.Xoroshiro128Plus(rand(UInt64)) from RandomNumbers.jl. Not only is this faster, but by not relying on the global rng this implementation is multithreading-safe by default.
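The pattern looks roughly like the sketch below, with the stdlib Xoshiro standing in for RandomNumbers.jl's Xoroshiro128Plus (the thread's actual suggestion); the ensemble driver and its names are hypothetical:

```julia
using Random

# Hypothetical ensemble driver: every simulation copy derives its own RNG
# from one base seed rather than sharing the global RNG, so runs are
# reproducible and safe under multithreading.
function run_ensemble(n, simulate; seed = rand(UInt64))
    results = Vector{Any}(undef, n)
    Threads.@threads for i in 1:n
        rng = Xoshiro(hash((seed, i)))  # independent stream per copy
        results[i] = simulate(rng)
    end
    results
end
```

Seeding per copy (instead of sharing one generator) is what makes the parallel case both fast and deterministic given the base seed.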

Rather than providing a new JumpAggregation for implementations that use function wrappers or some other structure, might it be possible to generate these from the aggregator algorithms?

I agree with this.

isaacsas commented 6 years ago

Rather than providing a new JumpAggregation for implementations that use function wrappers or some other structure, might it be possible to generate these from the aggregator algorithms?

So you'll set up dispatched versions of generate_jumps! based on the type of rates in the JumpAggregation? The time to next jump calculation would work differently for tuples (recursion) vs function wrappers (a loop). If this is obvious I apologize; I'm still pretty new to Julia. EDIT: Just saw that you actually said exactly this :)
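A sketch of that dispatch (hypothetical helper, not the package code): the tuple method recurses so the compiler specializes on each distinct rate-function type, while the vector method (e.g. of function wrappers, one concrete element type) just loops:

```julia
# Fill cr with the cumulative sum of the rates evaluated at (u, p, t_)
# and return the total rate. Tuple version: recursion unrolls the calls.
fill_cum_rates!(cr, i, s, t_, u, p, rates::Tuple{}) = s
function fill_cum_rates!(cr, i, s, t_, u, p, rates::Tuple)
    s += first(rates)(u, p, t_)
    cr[i] = s
    fill_cum_rates!(cr, i + 1, s, t_, u, p, Base.tail(rates))
end

# Vector version: a plain loop is fine for a homogeneous container.
function fill_cum_rates!(cr, t_, u, p, rates::AbstractVector)
    s = 0.0
    for (i, r) in enumerate(rates)
        s += r(u, p, t_)
        cr[i] = s
    end
    s
end
```

Both methods produce the same cr and total, so generate_jumps! can be written once against whichever storage the aggregator was built with.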

I'd also like it if we expand the JumpAggregation for SSA aggregators to include a ma_jumps field, like my second gist above. I think for anyone interested in SSAs that involve reactions this will be the preferred jump type since it is much quicker. I've got jumps.jl and problems.jl all updated to handle a fourth jump type, along with a direct method that can use both. I can submit a PR with them once we get the interface straightened out and I update to it.

If you'll allow me one more gist, we could take your idea of moving to a more function-based interface to the extreme. See here. Basically the functors are all just copy and pasted over, and could be made generic once language support is added (and hence removed from every SSA). For a new SSA we just need to implement: aggregate, initialize!, execute_jumps!, generate_jumps!. register_next_jump_time! sets the tstop condition and can be made generic over all SSA algorithms (and is only used within the functor templates). I think this is actually more along the lines of how you did things in BioSimulator?

alanderos91 commented 6 years ago

I'd also like it if we expand the JumpAggregation for SSA aggregators to include a ma_jumps field, like my second gist above.

Do you need this to be a new abstract type or built in directly to our SSAJumpAggregator? I think the latter is better.

isaacsas commented 6 years ago

I was just building it into the JumpAggregator

mutable struct DirectMAJumpAggregation{T,S,F1,F2,V} <: AbstractSSAJumpAggregator
    next_jump_time::T
    next_jump::Int64
    end_time::T
    cur_rates::Vector{T}
    sum_rate::T
    ma_jumps::S
    rates::F1
    affects!::F2
    save_positions::Tuple{Bool,Bool}
    rng::V
end

As you said, that is the natural place to stash additional info.

isaacsas commented 6 years ago

@ChrisRackauckas Just a heads up, there is at least one claim that Xoroshiro128+ might have some issues (not that MT is great either): here.

ChrisRackauckas commented 6 years ago

What about the star (*) variant?

isaacsas commented 6 years ago

In her testing she says it passes everything she's looked at, and she recommends "XorShift 128/64 (i.e., high 64 bits of 128-bit XorShift)". It's recommended on her blog here and here.

ChrisRackauckas commented 6 years ago

Alright, we can just change to using those as the default since RandomNumbers.jl has them and they benchmark better in the latest tests.

https://sunoru.github.io/RandomNumbers.jl/latest/man/benchmark/

isaacsas commented 6 years ago

I was mentioning XorShift, which is different from Xoroshiro128. I can't seem to find any info on whether the latter fails the same test battery that Xoroshiro128+ fails (the PractRand tests).

ChrisRackauckas commented 6 years ago

Oh okay. Xoroshiro128 should be much more robust than XorShift*.

alanderos91 commented 6 years ago

Closed by #20