pyomeca / cocofest

Cocofest is an optimal control Python package for functional electrical stimulation models
MIT License

Multiple Ding model creation and number of past stimulation #2

Open issue, opened by Kev1CO 2 months ago

Kev1CO commented 2 months ago

Subject of the enhancement

This is a possible refactor.

Proposed enhancement

model = DingModel(n_stim_prev: int, ....)
model.nb_stim_prev  # --> val
MultiDing(....).to_models  # --> [model1, model2, ...]
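
One possible reading of this count-based proposal, as a minimal sketch: each model only stores how many pulses precede its phase, and `MultiDing(...).to_models` returns one model per phase. The `MultiDing` constructor argument `n_stim` and the dataclass layout are hypothetical and are not the cocofest API.

```python
from dataclasses import dataclass


@dataclass
class DingModel:
    # Hypothetical minimal model: only tracks how many pulses precede its phase
    n_stim_prev: int = 0

    @property
    def nb_stim_prev(self) -> int:
        return self.n_stim_prev


class MultiDing:
    """Hypothetical helper that produces one DingModel per stimulation phase."""

    def __init__(self, n_stim: int):
        self.n_stim = n_stim

    @property
    def to_models(self) -> list[DingModel]:
        # Phase k is preceded by k earlier pulses
        return [DingModel(n_stim_prev=k) for k in range(self.n_stim)]


models = MultiDing(n_stim=4).to_models
print([m.nb_stim_prev for m in models])  # [0, 1, 2, 3]
```

The count alone is enough to size the model's internal sum over past pulses, but it carries no timing information, which is what the next comment's time-based proposal adds.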

Kev1CO commented 2 months ago

A proposal

from casadi import MX

model = DingModel(time_stim_prev=[0, 0.1, 0.2], time_current_stim=[0.3])
print(model.nb_stim_prev)  # 3
# if symbolic
model.set_stim_prev(MX.sym("t_prev", 3, 1))
model.set_current_stim(MX.sym("t_current", 1, 1))
# DingBuilder
ding_builder = DingBuilder(stim_time=[0, 0.1, 0.2, 0.3])
print(ding_builder.build())  # returns the list of Ding models, one per phase
# The stim setter should live in DingBuilder because Python objects are shared by reference,
ding_builder.set_stim_times(MX.sym("t", 4, 1))
# so all descendant models would follow, I believe!

Full prototype:

DingModel: manages the stimulation times and allows setting them with symbolic values if required.
DingBuilder: manages multiple DingModel instances and sets the stimulation times for all of them.

Example implementation:

from casadi import MX


def _length(times):
    # CasADi matrices expose numel(); plain Python sequences use len()
    return times.numel() if hasattr(times, "numel") else len(times)


class DingModel:
    def __init__(self, time_stim_prev=None, time_current_stim=None):
        self.time_stim_prev = time_stim_prev if time_stim_prev is not None else []
        self.time_current_stim = time_current_stim if time_current_stim is not None else []

    @property
    def nb_stim_prev(self):
        return _length(self.time_stim_prev)

    def set_stim_prev(self, stim):
        self.time_stim_prev = stim

    def set_current_stim(self, stim):
        self.time_current_stim = stim


class DingBuilder:
    def __init__(self, stim_time=None):
        self.stim_time = stim_time if stim_time is not None else []
        self.models = self._create_models()

    def _create_models(self):
        # Phase i: every earlier pulse is "previous", pulse i is the current one,
        # e.g. the last phase of [0, 0.1, 0.2, 0.3] gets prev=[0, 0.1, 0.2], current=[0.3]
        return [
            DingModel(time_stim_prev=self.stim_time[:i], time_current_stim=self.stim_time[i:i + 1])
            for i in range(_length(self.stim_time))
        ]

    def build(self):
        return self.models

    def set_stim_times(self, stim):
        self.stim_time = stim
        if _length(stim) != len(self.models):
            # Different number of pulses: the phases themselves change, so rebuild
            self.models = self._create_models()
            return
        # Same number of pulses: update the existing models in place so that
        # lists previously returned by build() see the new times
        for i, model in enumerate(self.models):
            model.set_stim_prev(stim[:i])
            model.set_current_stim(stim[i:i + 1])


# DingBuilder with numeric stimulation times
ding_builder = DingBuilder(stim_time=[0, 0.1, 0.2, 0.3])
models = ding_builder.build()
print([m.nb_stim_prev for m in models])  # [0, 1, 2, 3]

# Set symbolic stim times: the same model objects now hold MX slices
ding_builder.set_stim_times(MX.sym("t", 4, 1))
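
The phase split implied by the proposal (phase i carries every earlier pulse as its history and pulse i as its current stimulation) can be checked on a plain list. `split_phases` is a hypothetical helper name; the slicing follows the `DingModel(time_stim_prev=[0, 0.1, 0.2], time_current_stim=[0.3])` example above. For a CasADi `MX` column vector the same slice syntax applies, though the length is safer taken from `.numel()`.

```python
def split_phases(stim_time):
    """Return (previous pulses, current pulse) for each stimulation phase.

    Phase i sees stim_time[:i] as its past pulses and stim_time[i:i + 1]
    as its current pulse.
    """
    return [(stim_time[:i], stim_time[i:i + 1]) for i in range(len(stim_time))]


for prev, current in split_phases([0, 0.1, 0.2, 0.3]):
    print(prev, current)
# [] [0]
# [0] [0.1]
# [0, 0.1] [0.2]
# [0, 0.1, 0.2] [0.3]
```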

Explanation:

DingModel Class:

__init__: initializes the model with the previous and current stimulation times.
nb_stim_prev: property returning the number of previous stimulation times.
set_stim_prev and set_current_stim: set the previous and current stimulation times; both accept symbolic values such as CasADi's MX.

DingBuilder Class:

__init__: initializes the builder with a list of stimulation times and creates one model per phase.
_create_models: internal method that creates a DingModel instance for each phase from the stimulation times.
build: returns the list of DingModel instances.
set_stim_times: updates the stimulation times and passes them on to the models.

Usage:

DingModel: create an instance, set the stimulation times, and read the number of previous stimulations.
DingBuilder: create an instance with the stimulation times, build the models, and update the stimulation times symbolically.

This structure is meant to ensure that updating the stimulation times in DingBuilder propagates to all DingModel instances.
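
The "pointing properties of Python" mentioned earlier come from variables holding references to objects: if every model reads its times through one shared object, a single update is seen by all models, including those already returned by `build()`, with no rebuild. A minimal sketch of that design, with hypothetical class names (`StimTimes`, `PhaseModel`, `Builder`) that are not part of cocofest:

```python
class StimTimes:
    """Hypothetical shared container holding every pulse time of the train."""

    def __init__(self, values):
        self.values = values


class PhaseModel:
    """Hypothetical per-phase model reading its times through a shared reference."""

    def __init__(self, shared: StimTimes, phase: int):
        self.shared = shared
        self.phase = phase

    @property
    def time_stim_prev(self):
        return self.shared.values[:self.phase]

    @property
    def time_current_stim(self):
        return self.shared.values[self.phase:self.phase + 1]

    @property
    def nb_stim_prev(self):
        return self.phase


class Builder:
    def __init__(self, stim_time):
        self.shared = StimTimes(stim_time)
        self.models = [PhaseModel(self.shared, i) for i in range(len(stim_time))]

    def build(self):
        return self.models

    def set_stim_times(self, stim):
        # Rebinding the attribute on the shared object is visible to every model,
        # including lists handed out earlier by build()
        self.shared.values = stim


builder = Builder([0, 0.1, 0.2, 0.3])
models = builder.build()
builder.set_stim_times([0, 0.05, 0.15, 0.3])
print(models[3].time_stim_prev)  # [0, 0.05, 0.15]
```

The trade-off is that the number of phases is fixed at construction; changing the pulse count would still require a rebuild.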

Optimal control: setting the parameters when the stimulation times are variable.


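import numpy as np
from bioptim import InterpolationType, VariableScaling

# parameters, parameters_bounds, parameters_init, ding_builder, n_stim,
# time_min_list and time_max_list are assumed to be defined beforehand
parameters.add(
    name="pulse_apparition_time",
    function=ding_builder.set_pulse_apparition_time,  # the secret sauce!
    size=n_stim,
    scaling=VariableScaling("pulse_apparition_time", [1] * n_stim),
)
parameters_bounds.add(
    "pulse_apparition_time",
    min_bound=np.array(time_min_list),
    max_bound=np.array(time_max_list),
    interpolation=InterpolationType.CONSTANT,
)

# Initial guess: midpoint of each pulse's bounds
parameters_init["pulse_apparition_time"] = (np.array(time_min_list) + np.array(time_max_list)) / 2
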
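
The initial guess above is the midpoint of each pulse's bounds, which always lies inside those bounds. If the search windows are ordered and do not overlap (an assumption, not something the snippet enforces), the guess is also strictly increasing, so the pulses start in chronological order. A small check with made-up windows for three pulses:

```python
import numpy as np

# Hypothetical search windows for three pulses (in seconds), for illustration only
time_min_list = [0.0, 0.1, 0.2]
time_max_list = [0.05, 0.15, 0.25]

time_min = np.array(time_min_list)
time_max = np.array(time_max_list)
init = (time_min + time_max) / 2

# The midpoint always lies inside its own bounds
assert np.all((time_min <= init) & (init <= time_max))
# With ordered, non-overlapping windows the guess keeps the pulses in order
assert np.all(np.diff(init) > 0)
print(init)
```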
Kev1CO commented 2 months ago

Maybe it's easier to have:

model.set_all_pulse_apparition_time([.....])
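
A minimal sketch of this single-model alternative, assuming the model keeps the whole pulse train and derives the number of previous pulses from a query time instead of holding one model per phase. Turning `nb_stim_prev` into a method that takes `t` is a hypothetical design choice, and the sketch is numeric only: with symbolic `MX` times, `pulse < t` would not give a Python boolean.

```python
class DingModel:
    """Hypothetical single model that stores the whole pulse train at once."""

    def __init__(self):
        self.pulse_apparition_time = []

    def set_all_pulse_apparition_time(self, times):
        # Keep the full pulse train on one object (assumed chronological)
        self.pulse_apparition_time = list(times)

    def nb_stim_prev(self, t):
        # Number of pulses that started strictly before time t
        return sum(1 for pulse in self.pulse_apparition_time if pulse < t)


model = DingModel()
model.set_all_pulse_apparition_time([0, 0.1, 0.2, 0.3])
print(model.nb_stim_prev(0.3))  # 3
```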