neurophysik / jitcdde

Just-in-time compilation for delay differential equations

efficient control based on event detection #53

Closed LukeLabrie closed 4 months ago

LukeLabrie commented 5 months ago

I'm trying to implement a form of event-based control, and can't seem to figure out a good way to do it. Say I have a very simple system of the form $\dot{y} = 1$, but if $y(t) > 50$, I want it to exponentially decay with rate $\lambda$, i.e. $y(t) = y(t_1)e^{-(t-t_1)/\lambda}$, where $t_1$ is the time at which the condition $y(t) > 50$ is reached. To summarize:

$$\dot{y} = \begin{cases} 1, & t < t_1 \\ -\frac{y(t_1)}{\lambda}\,e^{-(t-t_1)/\lambda}, & t \ge t_1 \end{cases}$$

where $t_1$ is the time at which $y(t)$ reaches 50. I don't think this can be done with jitcdde_input alone, since those input splines require a priori knowledge of $t_1$. So the state needs to be checked at each integration step.

I know this can be naively implemented with two separate systems as:

import numpy as np
import symengine as sp
from jitcdde import jitcdde, t

# first system: constant growth until the threshold is reached
f1 = [1.0]
DDE0 = jitcdde(f1)
DDE0.constant_past([0.0])

# generate solution, checking the state after every step
T = np.arange(0, 100, 0.01)
sol1 = []
tripped = False
t_trip = 0.0
idx = 0
while not tripped:
    time = T[idx]
    sol1.append(DDE0.integrate(time))
    idx += 1
    # last anchor of the stored past: (time, state, derivative)
    anchor = DDE0.get_state()[-1]
    temp_current = anchor[1][0]
    if temp_current >= 50.0:
        tripped = True
        t_trip = anchor[0]

# new system with decay term, initialised with the past of the first
decay_rate = 5.0
f2 = [-(temp_current/decay_rate)*sp.exp(-(t-t_trip)/decay_rate)]
DDE1 = jitcdde(f2)
sol2 = []
DDE1.add_past_points(DDE0.get_state())
for time in T[idx:]:
    sol2.append(DDE1.integrate(time))
sol = sol1 + sol2

This is unfortunately very slow for larger systems. The bottleneck appears to be in the get_state() call. From a quick glance at the source code, it looks like this is creating a new object every time it's called, so perhaps it's not really suitable to be called at each integration step. Is there a more direct way to access the state during integration?

Another way I tried to do this was using a callback function, i.e., something of the form:

import numpy as np
from symengine import Function
from jitcdde import jitcdde, t

t_trip = 0.0
tripped = False
β = Function('param_jump')

def param_jump_callback(y, t):
    # right-hand side: 1 before the trip, exponential decay afterwards
    global tripped, t_trip
    if (y[0] < 50) and not tripped:
        return 1.0
    elif not tripped:
        tripped = True
        t_trip = t
    return -(50/10)*np.exp(-(t-t_trip)/10)

# system
f = [1*β(t)]
DDE = jitcdde(f, callback_functions=[(β, param_jump_callback, 2)])
DDE.constant_past([1.0])
DDE.step_on_discontinuities()

but this leads to seg faults, presumably because these callbacks do not get write access to variables, and therefore I'm unable to extract $t_1$ for future use.

My question is therefore the following: Is it possible to efficiently check, and extract information from the state during integration? If not, could it be implemented somehow?

Let me know if any of this is unclear.

Wrzlprmft commented 5 months ago

Is there a more direct way to access the state during integration?

Yes: DDE.integrate returns the current state of the system. At least for your example, this should suffice to detect when the threshold has been exceeded. Then you can use get_state once to pin-point the time more exactly, truncate, and use it as a new initial condition (see the documentation for CHSPy).

If the output of DDE.integrate doesn’t contain sufficient information for you, there are probably ways to make things faster, but then I would need more details, e.g., whether you need the derivative or part of the past.
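For illustration, here is a rough, untested sketch of what I mean, reusing the threshold and decay rate from your example (if you need the switching time more precisely than the end of the last integration step, you can additionally refine and truncate the spline as described in the CHSPy documentation):

import numpy as np
import symengine as sp
from jitcdde import jitcdde, t

threshold = 50.0
decay_rate = 5.0

# phase 1: constant growth; check the state returned by integrate,
# so get_state is not called at every step
DDE0 = jitcdde([1.0])
DDE0.constant_past([0.0])
T = np.arange(0, 100, 0.01)
sol1 = []
for idx, time in enumerate(T):
    state = DDE0.integrate(time)
    sol1.append(state)
    if state[0] >= threshold:
        break

# one get_state call to obtain the stored past and the switching point
past = DDE0.get_state()
t_trip = past[-1][0]
y_trip = past[-1][1][0]

# phase 2: decaying system initialised with that past
DDE1 = jitcdde([-(y_trip/decay_rate)*sp.exp(-(t-t_trip)/decay_rate)])
DDE1.add_past_points(past)
sol2 = [DDE1.integrate(time) for time in T[idx+1:]]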

Another way I tried to do this was using a callback function, … but this leads to seg faults, presumably because these callbacks do not get write access to variables …

At first glance, this is because the number of arguments is wrong (the first argument, i.e., the state, is not counted). I tested that you can write to global variables in the callback, so that's not the problem.
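Concretely, since β is only ever called with t (the state is passed to the Python function automatically and not counted), the registration should be:

DDE = jitcdde(f, callback_functions=[(β, param_jump_callback, 1)])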

Anyway, using callbacks this way is problematic for the following reason: during the integration, the derivative is also evaluated at states and times that are not part of the final result. This in turn happens for two reasons:

- To perform an integration step, the derivative has to be evaluated at several trial points within the step, not only at the points that end up in the solution.
- If a step's error estimate is too large, the step is rejected and repeated with a smaller step size, so its evaluations are discarded entirely.

Each of these evaluations calls your callback function and may set tripped even though the actual system didn't trip.

Finally, a cursory sanity check: The absence of any delay term (and thus a reason to make this a DDE) is due to the minimisation of the example, right?

LukeLabrie commented 5 months ago

Yes: DDE.integrate returns the current state of the system. At least for your example, this should suffice to detect when the threshold has been exceeded. Then you can use get_state once to pin-point the time more exactly, truncate, and use it as a new initial condition (see the documentation for CHSPy).

This indeed worked well, thanks. I misunderstood what DDE.get_state() was really doing. I basically thought it was storing the state and the derivatives for each time-step, but it's actually just storing the minimum number of 'anchors' needed to interpolate the past within a given tolerance. Is this understanding correct?

Each of these evaluations calls your callback function and may set tripped even though the actual system didn't trip.

Ah, makes sense. Thanks for the explanation.

Finally, a cursory sanity check: The absence of any delay term (and thus a reason to make this a DDE) is due to the minimisation of the example, right?

Yep, I am working with delay dynamics. Otherwise I'd just be using scipy.integrate.

Wrzlprmft commented 5 months ago

This indeed worked well, thanks. I misunderstood what DDE.get_state() was really doing. I basically thought it was storing the state and the derivatives for each time-step, but it's actually just storing the minimum number of 'anchors' needed to interpolate the past within a given tolerance. Is this understanding correct?

Yes, if by time-step you mean sampling step and not integration step. get_state returns one anchor per integration step, which is what is necessary to interpolate past states as accurately as the integration itself. It does not keep all of these anchors, however, but only those needed to interpolate delayed states, i.e., those closer to the present than the maximum delay. Older anchors are automatically discarded.

(This is a possible cause of erratic behaviour if you are using a minimal example without a delay.)
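If you want to see what is being kept, you can simply iterate over the result of get_state; each anchor holds the time, the state, and the derivative (your own snippet already indexes the first two):

state = DDE.get_state()   # the retained anchors, one per integration step
for anchor in state:
    print(anchor[0], anchor[1], anchor[2])   # time, state, derivative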