neurophysik / jitcdde

Just-in-time compilation for delay differential equations

How to speed up integration? #36

Closed chrisajohnson closed 2 years ago

chrisajohnson commented 3 years ago

Hi,

First off, this is a really neat package that I think is exactly what I need to run DDE models for my research. Thanks!

I have a question about how to speed up the integration of my code (a minimal working example is below). Very briefly, I am using a variable-time delay model to predict the abundance of an insect species with juvenile (J) and adult (A) life stages. A third DDE equation describes juvenile survival (S) and the fourth equation describes the development time delay (tau). The maturation rate from juveniles to adults is dependent on the habitat temperature, which is modeled via a sine wave. For simplicity, I am using a constant past and I am integrating blindly (following advice in the documentation). To deal with initial discontinuities, I am simply running the model for long enough for transients to dissipate and the system to settle into equilibrium cycles.

The code seems to be working (awesome!), but it takes 11-12 hours to run on my Mac laptop with Python 3.8.8 and JiTCDDE 1.8. For reference, I ran a similar model years ago using the PyDDE package on the same laptop, and that code took ~15-20 minutes to run. I don't have access to a better computer, but I feel that I should be able to write more efficient code that speeds up the integration. I am a Python newbie, so I am sure that any problems lie with my implementation. I tried my best to follow the Common Mistakes and Questions suggestions, but must admit that some of the information about generators and Clang went over my head.

I hesitate to post this as an issue because it's very likely my own naive coding, but I am wondering if you have any suggestions for speeding up the integration. Thanks so much in advance, it would be a life saver!

Best, Chris

# Delay differential equation model for predicting insect population dynamics
# under seasonal temperature variation and climate change

# IMPORT PACKAGES
from numpy import array, arange, hstack, vstack, savetxt, pi
from jitcdde import jitcdde, y, t
from symengine import exp, sin
from matplotlib.pyplot import subplots, xlabel, ylabel, xlim, ylim #, yscale
from pandas import read_csv

# INPUT TEMPERATURE RESPONSE DATA
tempData = read_csv("Temperature response data.csv")

# DEFINE MODEL PARAMETERS
# Time parameters
yr = 365.0 # days in year
max_years = 10. # how long to run simulations
tstep = 1. # time step = 1 day

# Habitat temperature parameters
meanT = 300.2
amplT = 1.36 
shiftT = 35.39

# Life history and competitive traits
# fecundity
bTopt = 8.9
Toptb = 300.4
sb = 3.64
# maturation
mTR = 0.0334
TR = 298
AmJ = 7000
skew = 1
AL = -70694
TL = 289.8
AH = 146729
TH = 308.7
# mortality
dJTR = 0.013
AdJ = 23770
dATR = 0.0265
AdA = 9710
# competition
qTR = 1
Toptq = Toptb
sq = sb
Aq = AdA
Tmax = meanT + amplT
qTemp = 1
qTopt = qTR*exp(Aq*(1./TR - 1./Tmax))

# FUNCTIONS
# Seasonal temperature variation (K) over time
def T(x):
    return meanT + amplT * sin(2*pi*(x + shiftT)/yr)

# Life history functions
# fecundity
def b(x):
    return bTopt * exp(-(T(x)-Toptb)**2/2./sb**2)

# maturation rates
def mJ(x):
    return mTR * T(x)/TR * exp(AmJ * (1./TR - 1./T(x))) / (1. + skew * (exp(AL*(1./TL-1./T(x)))+exp(AH*(1./TH-1./T(x)))))

# mortality rates
def dJ(x):
    return dJTR * exp(AdJ * (1./TR - 1./T(x)))
def dA(x):
    return dATR * exp(AdA * (1./TR - 1./T(x)))

# density-dependence due to competition
def q(x):
    return (1-qTemp) * qTR + qTemp * qTopt * exp(-(T(x)-Toptq)**2/2./sq**2)

# DDE MODEL
# Define state variables
J,A,S,τ = [y(i) for i in range(4)]

# Model
f = {
    J: b(t)*A*exp(-q(t)*A) - b(t-τ)*y(1,t-τ)*exp(-q(t-τ)*y(1,t-τ))*mJ(t)/mJ(t-τ)*S - dJ(t)*J,

    A: b(t-τ)*y(1,t-τ)*exp(-q(t-τ)*y(1,t-τ))*mJ(t)/mJ(t-τ)*S - dA(t)*A,

    S: S*(mJ(t)/mJ(t-τ)*dJ(t-τ) - dJ(t)),

    τ: 1. -  mJ(t)/mJ(t-τ)
    }

# RUN DDE SOLVER
# Time and initial conditions
times = arange(0., max_years*yr, tstep)
init = array([ 10., 1., exp(-dJ(-1e-3)/mTR), 1./mTR ])

# DDE solver
DDE = jitcdde(f, max_delay=100, verbose=False)
DDE.constant_past(init)
DDE.integrate_blindly(0.01)
DDE.compile_C(simplify=False, do_cse=False, chunk_size=30, verbose=True) # These options are used in an attempt to speed up the compiler

# Save data array containing time and state variables
data = vstack([ hstack([time,DDE.integrate(time)]) for time in times ])
#filename = 'Time_series.csv'  
#savetxt(filename, data, fmt='%s', delimiter=",", header="Time,J,A,S,tau", comments='') 

# PLOT
fig,ax = subplots()
ax.plot(data[:,0], data[:,1], label='J')
ax.plot(data[:,0], data[:,2], label='A')
#ax.plot(data[:,0], data[:,3], label='S')
#ax.plot(data[:,0], data[:,4], label='τ')
ax.legend(loc='best')
xlabel("time (days)")
ylabel("population density")
xlim((max_years-10)*yr,max_years*yr)
ylim(0,40)
#yscale("log")
#ylim(0.1,100)

Wrzlprmft commented 3 years ago

As far as I can tell at a quick glance, most time is spent by SymEngine trying to simplify your equations, and not on integration. Thus calling compile_C with simplify=False was the right call (you don’t need to set the chunk_size though, as you only have four dynamical variables). However, you need to do this before calling integrate_blindly, as that call already needs the compiled code. In your current code, JiTCDDE first compiles inefficiently, then calls integrate_blindly, and then compiles efficiently. When I compile first, everything runs within seconds on my machine.
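
To make the reordering concrete, here is a minimal sketch of the call order described above. The helper name `prepare_solver` and its `blind_time` parameter are hypothetical and only serve as illustration; the methods it calls (`constant_past`, `compile_C`, `integrate_blindly`) are the ones already used in the original post.

```python
def prepare_solver(DDE, past, blind_time=0.01):
    """Hypothetical helper: set the past, compile once, then integrate blindly.

    `DDE` is assumed to be a jitcdde object as created in the original post,
    e.g. jitcdde(f, max_delay=100, verbose=False).
    """
    DDE.constant_past(past)
    # Compile explicitly BEFORE the first integration call. Otherwise the
    # first integration call triggers a compilation with default settings,
    # which is the slow step reported in this thread.
    DDE.compile_C(simplify=False, do_cse=False, verbose=True)
    # Only now deal with the initial discontinuities.
    DDE.integrate_blindly(blind_time)
    return DDE
```

With the definitions from the original post, this would be used as `DDE = prepare_solver(jitcdde(f, max_delay=100, verbose=False), init)`, followed by the unchanged `DDE.integrate(time)` loop. `chunk_size` is left out here because, as noted above, it is not needed for four dynamical variables.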

Some further remarks:

chrisajohnson commented 3 years ago

Awesome!!! Thank you so much for the super fast reply, it worked like a charm! Such a great package!