SciML / JumpProcesses.jl

Build and simulate jump equations like Gillespie simulations and jump diffusions with constant and state-dependent rates and mix with differential equations and scientific machine learning (SciML)
https://docs.sciml.ai/JumpProcesses/stable/

Memory allocations when calling FunctionWrapped Symbolic affect functions #308

Closed: isaacsas closed this issue 1 year ago

isaacsas commented 1 year ago
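using Catalyst, JumpProcesses, BenchmarkTools

stepper = SSAStepper()   # assumed: the stepper passed to solve in the @btime calls below is not shown in the original report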

function catalyst_plasmid(; T::Float64 = 250.0, alg = Direct())
    # Logistic per-capita reproduction rate with carrying capacity K
    function rate(η, X, Y, K)
        return η * (1 - (X + Y) / K)
    end

    fp_model = @reaction_network begin
        rate(η,F,P,K), F --> 2F   # Background reproduction F
        rate(η,P,F,K), P --> 2P   # Background reproduction P  !NOTE: order of F & P is switched
        μ, F --> 0                # Background mortality F
        μ, P --> 0                # Background mortality P
        γ, F + P --> 2P           # Infection (conjugation)
        ρ, P --> F                # Recovery (segregation error)
    end
    p = (:η => 1.0, :μ => 0.1, :γ => 1e-5, :ρ => 0.01, :K => 1e4)
    u0 = [:F => 100, :P => 10]
    tspan = (0.0, T)

    # Pure jump problem: no intermediate saves, so only the initial and final states are stored
    dprob = DiscreteProblem(fp_model, u0, tspan, p)
    jump_prob = JumpProblem(fp_model, dprob, alg; save_positions = (false, false))
    return jump_prob
end
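For reference, here are the propensities this model should produce, assuming Catalyst's default rate law (the rate expression times the substrate counts, so for F --> 2F the propensity is rate(η, F, P, K) times F). The two reproduction reactions have species-dependent rates, so they cannot be mass action jumps; this is consistent with the 2 jumps with discrete aggregation and 4 mass action jumps printed in the REPL output. Because X + Y is symmetric, swapping F and P in the second rate call gives the same logistic factor.

$$
\begin{aligned}
a_1 &= \eta \left(1 - \frac{F + P}{K}\right) F, & a_2 &= \eta \left(1 - \frac{P + F}{K}\right) P,\\
a_3 &= \mu F, & a_4 &= \mu P,\\
a_5 &= \gamma F P, & a_6 &= \rho P.
\end{aligned}
$$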
julia> short_prob = catalyst_plasmid(T=1e3, alg = RSSA())
JumpProblem with problem DiscreteProblem with aggregator RSSA
Number of jumps with discrete aggregation: 2
Number of jumps with continuous aggregation: 0
Number of mass action jumps: 4

julia> @btime solve($short_prob, $stepper)
  166.378 ms (903165 allocations: 13.82 MiB)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 2-element Vector{Float64}:
    0.0
 1000.0
u: 2-element Vector{Vector{Int64}}:
 [100, 10]
 [918, 8077]

julia> long_prob = catalyst_plasmid(T=1e5, alg = RSSA())
JumpProblem with problem DiscreteProblem with aggregator RSSA
Number of jumps with discrete aggregation: 2
Number of jumps with continuous aggregation: 0
Number of mass action jumps: 4

julia> @btime solve($long_prob, $stepper)
  16.709 s (90009068 allocations: 1.34 GiB)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 2-element Vector{Float64}:
      0.0
 100000.0
u: 2-element Vector{Vector{Int64}}:
 [100, 10]
 [998, 8034]

From what I can tell, the allocations come from the FunctionWrappers-wrapped affect functions created at

https://github.com/SciML/JumpProcesses.jl/blob/888a36b86118f008dc1018871ef4715eb1d60324/src/jumps.jl#L691-L698

with the allocations appearing at

https://github.com/SciML/JumpProcesses.jl/blob/888a36b86118f008dc1018871ef4715eb1d60324/src/aggregators/ssajump.jl#L184

where an affect function within this vector is called.

This behavior only seems to occur with the RuntimeGeneratedFunctions produced by MTK (ModelingToolkit) and Symbolics; hand-coded affect functions do not seem to show the same allocations.

I tried dropping the anonymous function that wraps the user affect function and returns nothing. That did reduce the allocations, but it didn't eliminate them. The only other idea I have is that the argument type in the function wrapper perhaps shouldn't be Any. Using a concrete type, however, would mean deferring construction of the wrappers until after the integrator has been created, which is somewhat circular, since the integrator ultimately stores the aggregation that holds the wrappers. I also don't see why an Any argument type would cause problems only for RuntimeGeneratedFunctions.
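A minimal sketch of the wrapping pattern described above, assuming the FunctionWrappers.jl package is installed. The affect functions affect1! and affect2! and the helper wrap_affect are hypothetical stand-ins that act on a plain vector rather than an integrator. Each user function is wrapped in an anonymous function that returns nothing, so it satisfies the wrapper's declared Nothing return type, and every wrapper takes one argument of type Any, so all wrappers share a single concrete type and fit in one Vector. This only illustrates the pattern; it is not the JumpProcesses source and is not expected to reproduce the RuntimeGeneratedFunction allocations.

```julia
using FunctionWrappers: FunctionWrapper

# Hypothetical affect functions; real ones mutate the integrator state.
affect1!(u) = (u[1] += 1)   # returns the new value, not nothing
affect2!(u) = (u[2] -= 1)

# All wrappers share this one concrete type: a single argument of type Any
# and return type Nothing.
const AffectWrapper = FunctionWrapper{Nothing, Tuple{Any}}

# Hypothetical helper: the anonymous function discards the user function's
# return value so the wrapper's Nothing return type is satisfied.
wrap_affect(f) = AffectWrapper(u -> (f(u); nothing))

# A concretely typed vector holding functions of different original types.
affects = AffectWrapper[wrap_affect(f) for f in (affect1!, affect2!)]

u = [0, 5]
affects[1](u)   # calls affect1! through the wrapper
affects[2](u)   # calls affect2! through the wrapper
```

Without the anonymous wrapper, calling affects[1] would try to convert the Int returned by affect1! to Nothing and fail, which is why the wrapper returning nothing is needed when the user function returns something else.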

I'm not sure how to investigate this further, as it is getting a bit more low level than my current understanding (I don't really know much about how either RuntimeGeneratedFunctions or FunctionWrappers work under the hood).

isaacsas commented 1 year ago

Note that the new version of Direct, which uses tuples and manual splitting, has no such issues.
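A minimal Base-only sketch of a tuple-based alternative, under the assumption that "manual splitting" means reaching the i-th element of a heterogeneous tuple of affect functions through a chain of branches instead of through a wrapper. Each tuple element keeps its own concrete type, so every call can be statically dispatched and no FunctionWrapper is needed. The name apply_ith! and the affect functions are hypothetical; this is a sketch of the general technique, not the JumpProcesses implementation.

```julia
# Call the i-th function in a tuple of functions. Each recursive step peels off
# the first element with Base.tail, so each call is specialized on a shorter,
# concrete tuple type and the selected function is called without a wrapper.
@inline function apply_ith!(fs::Tuple, i::Int, u)
    if i == 1
        first(fs)(u)
    else
        apply_ith!(Base.tail(fs), i - 1, u)
    end
    return nothing
end

# Base case: the index ran past the end of the tuple.
apply_ith!(::Tuple{}, i::Int, u) = throw(ArgumentError("affect index out of range"))

# Hypothetical affect functions, each with its own concrete closure type.
affects = (u -> (u[1] += 1), u -> (u[2] -= 1), u -> (u[1] = 0))

u = [0, 5]
apply_ith!(affects, 2, u)   # only the second function runs: u == [0, 4]
apply_ith!(affects, 1, u)   # only the first function runs:  u == [1, 4]
```

The trade-off is compile time: every distinct tuple type produces a new specialization, so this suits a modest number of jumps better than a very large one. Whether it removes the allocations seen with RuntimeGeneratedFunctions would need to be checked, for example with BenchmarkTools' @btime as in the report.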

isaacsas commented 1 year ago

Should be fixed by https://github.com/SciML/JumpProcesses.jl/pull/309