neurophysik / jitcdde

Just-in-time compilation for delay differential equations

How to speed up setting parameters? #38

Closed MichaelLLi closed 2 years ago

MichaelLLi commented 2 years ago

Hi,

I have a question about whether there is a way to speed up setting parameters. I have followed the tutorial and the example code in #30 to set up the DDE for an outer parameter optimization. The (hopefully) relevant part of the model code is as follows:

import numpy as np
from jitcdde import jitcdde, y, t, jitcdde_input, input
import sympy as sp
from chspy import CubicHermiteSpline
from ttictoc import tic, toc

IncubeD = 5
RecoverID = 10
RecoverHD = 15
DetectD = 7
maxT = 300
t_delay = 14
# N and beta are fixed constants defined elsewhere in the full script (not shown here)

S,V,E,EV,I,IV,IQ,ID,IU,R,D,DC = [y(i) for i in range(12)]
IQd = y(6,t-t_delay)
IDd = y(7,t-t_delay)
control_pars = [alpha, r_gov, p_und, p_dth, r_dth, d_decay, p_d, k1, k2, k3] = sp.symbols("alpha rgov pund pdth rdth ddecay pd k1 k2 k3")
gamma_t = (
    sp.Max(1 - r_gov * sp.Max(sp.log(IQd + IDd + 1)/ (np.log(N)), 0), 0)
)            
r_i = np.log(2) / IncubeD  # Rate of infection leaving incubation phase
r_d = np.log(2) / DetectD  # Rate of detection
r_ru = np.log(2) / RecoverID  # Rate of recovery not under infection
r_rq = np.log(2) / RecoverHD  # Rate of recovery under hospitalization
r_decay = 1 / d_decay  
t_predictions = [i for i in range(maxT)]
# Input spline that is zero at every prediction time (one input component)
input_spline = CubicHermiteSpline.from_data(np.array(t_predictions), np.zeros((len(t_predictions), 1)))

model_long_covid = {
    S: -alpha * gamma_t * S * (I + p_und * IU)  / N - input(0) + r_decay * V + r_decay * R,
    V: input(0) - alpha * (1 - beta) * gamma_t * V * (I + p_und * IU) / N - r_decay * V,
    E:  alpha * gamma_t * S * (I + p_und * IU)  / N  - r_i * E,
    EV: alpha * (1 - beta) * gamma_t * V * (I + p_und * IU)  / N - r_i * EV,
    I: r_i * E - r_d * I,
    IV: r_i * EV - r_d * IV,
    IQ: p_d * (r_d * IV  + r_d * (1 - p_dth) * I) - r_rq * IQ,
    ID: r_d * p_dth * I - r_dth * ID,
    IU: (1 - p_d) * (r_d * IV  + r_d * (1 - p_dth) * I) - r_ru * IU,
    R: r_rq * IQ + r_ru * IU - r_decay * R,
    D: r_dth * ID,
    # Helper states (usually important for some kind of output)
    DC: p_d * (r_d * IV  + r_d * (1 - p_dth) * I) + r_d * p_dth * I
    }

DDE = jitcdde_input(model_long_covid, n=12, input = input_spline, control_pars=control_pars, max_delay=t_delay, verbose=False )
DDE.compile_C(simplify=True, do_cse=False, chunk_size=30, verbose=True) 

def residuals_totalcases(params) -> float:
    """
    Function that makes sure the parameters are in the right range during the fitting process and computes
    the loss function depending on the optimizer that has been chosen for this run as a global variable
    :param params: currently fitted values of the parameters during the fitting process
    :return: the value of the loss function as a float that is optimized against (in our case, minimized)
    """
    # Variables Initialization for the ODE system
    params =  tuple(params_validation(params))
    x_0_cases = get_initial_conditions(
        params_fitted=params, global_params_fixed=GLOBAL_PARAMS_FIXED
    )
    # Force params values to stay in a certain range during the optimization process with re-initializations
    tic()
    DDE.purge_past()
    print(f"purge past: {toc()}")
    tic()
    DDE.constant_past(x_0_cases)      
    print(f"setting constant_past: {toc()}")           
    tic()
    DDE.set_parameters(params)
    print(f"setting parameters: {toc()}")
    tic()
    DDE.adjust_diff()
    print(f"adjust diff: {toc()}")
    tic()
    x_sol = np.transpose(np.vstack([ DDE.integrate(time) for time in t_cases ]))
    print(f"actual integration time: {toc()}")
    weights = np.ones(len(cases_data_fit))
    residuals_value = get_residuals_value(
        optimizer=OPTIMIZER,
        balance=balance,
        x_sol=x_sol,
        cases_data_fit=cases_data_fit,
        deaths_data_fit=deaths_data_fit,
        weights=weights,
        balance_total_difference=balance_total_difference 
    )
    return residuals_value

Basically, I wrote a wrapper function residuals_totalcases around the DDE model model_long_covid for downstream optimization (which I won't bore you with here). As my liberal use of the tictoc package shows, I have been trying to figure out what takes most of the time, because the DDE formulation of these equations runs roughly 10x slower than the equivalent set of equations without a delay term (i.e., with the constant delay set to 0). My timing experiments show the following (time of each step in seconds):

[Image: ddetime, timing of each step in seconds]

As you can see, setting the parameters is by far the slowest step and takes most of the time. This is somewhat perplexing to me, as I would not expect setting parameters to take longer than the integration itself. So I am wondering whether I am doing something wrong and whether there is any way to speed the integration up. Thanks a lot!

Wrzlprmft commented 2 years ago

This is indeed strange and something worth investigating. However, I cannot see a clear cause just by looking at your code and mine, in particular since we are talking about smaller time scales than I am used to optimising for typical applications of JiTCDDE. I am afraid that I can only help you if you provide a running example that I can play with or profile. Alternatively, you can profile your code yourself to see in which line of JiTCDDE the time is predominantly spent.
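If you want to profile this yourself, Python's standard-library cProfile and pstats modules are sufficient. The helper below is a minimal sketch; the names profile_call and slow_sum are hypothetical. With your code you would call something like profile_call(residuals_totalcases, params) and look for JiTCDDE or CHSPy functions near the top of the ranking.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, n_lines=15, sort_key="cumulative", **kwargs):
    """Run func(*args, **kwargs) under cProfile.

    Returns the function's result and a text report listing the n_lines
    entries with the largest value of sort_key ("cumulative" includes time
    spent in sub-calls; "tottime" counts only time spent in the function itself).
    """
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = func(*args, **kwargs)
    finally:
        profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats(sort_key).print_stats(n_lines)
    return result, stream.getvalue()


# Hypothetical stand-in workload, replace with your own function and arguments
def slow_sum(n):
    return sum(i * i for i in range(n))


value, report = profile_call(slow_sum, 10_000)
print(report)
```

Sorting by cumulative time points to the call chain that dominates, while switching sort_key to "tottime" helps find the single function where the time is actually burned.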

A very wild guess is that unpacking the parameters helps, i.e. use DDE.set_parameters(*params) instead of DDE.set_parameters(params).
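In plain Python, the difference between the two calls is easiest to see with a variadic function. The sketch below uses a hypothetical set_parameters_toy and assumes a variadic signature like `def f(*parameters)`; it does not reproduce JiTCDDE's actual set_parameters, which may accept both forms.

```python
def set_parameters_toy(*parameters):
    """Hypothetical stand-in with a variadic signature.

    It only reports how many positional arguments it received.
    """
    return len(parameters)


params = (0.1, 0.2, 0.3)

# Passing the tuple itself: the function receives ONE argument (the tuple).
print(set_parameters_toy(params))
# Unpacking with *: the function receives each value as a separate argument.
print(set_parameters_toy(*params))
```

The first call prints 1 and the second prints 3, so a function that has to detect and flatten a single iterable argument does slightly more work than one that receives the values directly.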

Wrzlprmft commented 2 years ago

So, I turned your code into something that runs and could reproduce your issue. Profiling it then revealed the reason:

Setting parameters itself is not the problem, but doing it for the first time triggers some automatic initialisations. These would have to happen at some point anyway, so you cannot avoid them. You can verify this by setting the parameters again, which is then very fast. The automatic initialisation is comparatively slow in your case only because the input spline and the regular past of your equation need to be joined into one spline under the hood. This in turn is slow because it all happens in CHSPy, which was not built for efficiency since I did not expect it to become a relevant bottleneck.
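One way to run the suggested check is to time several consecutive calls and compare the first against the rest. The sketch below is an assumption-laden illustration: time_calls and LazyModel are hypothetical, and LazyModel merely mimics a one-time initialisation with a sleep. With a real integrator you would pass something like `lambda: DDE.set_parameters(*params)` instead.

```python
import time


def time_calls(func, repeats=3):
    """Call func() `repeats` times and return the wall-clock duration of each call."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        func()
        durations.append(time.perf_counter() - start)
    return durations


class LazyModel:
    """Hypothetical toy object whose first set_parameters call triggers a one-time setup."""

    def __init__(self):
        self._initialised = False
        self.parameters = None

    def set_parameters(self, *parameters):
        if not self._initialised:
            time.sleep(0.2)  # stands in for an expensive one-time initialisation
            self._initialised = True
        self.parameters = parameters


model = LazyModel()
first, second, third = time_calls(lambda: model.set_parameters(0.1, 0.2))
print(f"first: {first:.3f} s, second: {second:.3f} s, third: {third:.3f} s")
```

If only the first duration is large, the cost lies in the initialisation rather than in setting the parameters as such.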

Speeding this up without touching your existing code would require implementing a compiled backend for CHSPy, which is not work I see myself or anybody else doing soon. However, I see the following options to possibly speed things up:

Which of these is most efficient or easiest to implement depends on the exact nature of your input, how often it changes, etc.

Wrzlprmft commented 2 years ago

I made a change that avoids unnecessarily repeating the joining in question and thus halves the time required for this initialisation. To use it, you need the GitHub version, which you can install with something along the lines of:

pip3 install git+https://github.com/neurophysik/jitcdde
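The idea behind the change (not repeating the join when nothing relevant has changed) can be illustrated with a generic dirty-flag cache. This is a hypothetical sketch of the pattern only: JoinCache and expensive_join are invented names, and the actual change in JiTCDDE may work quite differently.

```python
class JoinCache:
    """Hypothetical cache that recomputes an expensive join only after its inputs change."""

    def __init__(self, join):
        self._join = join      # expensive function combining input and past
        self._input = None
        self._past = None
        self._joined = None
        self._dirty = True     # True means the cached result is out of date

    def set_input(self, input_data):
        self._input = input_data
        self._dirty = True

    def set_past(self, past):
        self._past = past
        self._dirty = True

    def joined(self):
        # Recompute only if an input changed since the last join
        if self._dirty:
            self._joined = self._join(self._input, self._past)
            self._dirty = False
        return self._joined


calls = []


def expensive_join(input_data, past):
    calls.append(1)  # record each real computation
    return list(past) + list(input_data)


cache = JoinCache(expensive_join)
cache.set_input([1, 2])
cache.set_past([0])
cache.joined()
cache.joined()       # reuses the cached result, no second join
print(len(calls))
```

Here the join runs once for two requests; calling set_past or set_input again marks the cache as stale so that the next request recomputes it.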