ICB-DCM / pyPESTO

python Parameter EStimation TOolbox
https://pypesto.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Using PyPesto with jax #1428

Open MaAl13 opened 1 month ago

MaAl13 commented 1 month ago

Hello, I want to use your package to estimate parameters of ODEs and later compute confidence intervals with profile likelihood. However, the following code does not work for me on a toy example. Can you tell me what I am doing wrong? I want to use scatter search since it has been shown to have better convergence properties than purely local or global methods.

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

y0, ts, noisy_values = get_data()

# Define objective function
@jax.jit
def objective(parameters):
    pred_values = solve(parameters, y0, ts)
    return jnp.sum((noisy_values - pred_values)**2)

#objective_with_grad = jax.value_and_grad(objective)

objective = pypesto.Objective(
    fun=objective,
    grad=jax.grad(objective)
)

problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
optimizer = pypesto.optimize.SacessOptimizer(ess_init_args = default_ess_options, max_walltime_s=600)
result_custom_problem = optimizer.minimize(problem=problem1)

FFroehlich commented 1 month ago

I think you will need to be a bit more specific what you mean by "not working".

MaAl13 commented 1 month ago

I get the following error running the above code:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 125, in _main
    prepare(preparation_data)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 236, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 287, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 56, in <module>
    result_custom_problem = optimizer.minimize(problem=problem1)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/site-packages/pypesto/optimize/ess/sacess.py", line 210, in minimize
    with self.mp_ctx.Manager() as shmem_manager:
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 57, in Manager
    m.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/managers.py", line 562, in start
    self._process.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 154, in get_preparation_data
    _check_not_importing_main()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/spawn.py", line 134, in _check_not_importing_main
    raise RuntimeError('''
RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase.

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.

Traceback (most recent call last):
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 56, in <module>
    result_custom_problem = optimizer.minimize(problem=problem1)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/site-packages/pypesto/optimize/ess/sacess.py", line 210, in minimize
    with self.mp_ctx.Manager() as shmem_manager:
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 57, in Manager
    m.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/managers.py", line 566, in start
    self._address = reader.recv()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/connection.py", line 383, in _recv
    raise EOFError
EOFError

FFroehlich commented 1 month ago

Do you get the same error with standard optimisers in pypesto? My guess is that this is about serialisation of the model: sacess relies on pickle/deepcopy, which doesn't play nicely with jax.

MaAl13 commented 1 month ago

Thanks for getting back so fast! :)

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

y0, ts, noisy_values = get_data()

# Define objective function
@jax.jit
def objective(parameters):
    pred_values = solve(parameters, y0, ts)
    return jnp.sum((noisy_values - pred_values)**2)

#objective_with_grad = jax.value_and_grad(objective)

objective = pypesto.Objective(
    fun=objective,
    grad=jax.grad(objective)
)

problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
optimizer = optimize.ScipyOptimizer()
engine = pypesto.engine.SingleCoreEngine()
n_starts = 20
result = optimize.minimize(
    problem=problem1, optimizer=optimizer, n_starts=n_starts, engine=engine
)

print(result.summary())

It now runs through, but the line search fails badly and the optimizer doesn't get away from the initial guesses:

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00, 8.25it/s]

Optimization Result

A summary of the best run:

Optimizer Result

dweindl commented 1 month ago

Hi @MaAl13, when using multiprocessing (directly or indirectly as here through SacessOptimizer), always protect your module-level code with if __name__ == '__main__': as suggested in the error message above.

I.e., in your case:

import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

def main(): 
    y0, ts, noisy_values = get_data()

    # Define objective function
    @jax.jit
    def objective(parameters):
        pred_values = solve(parameters, y0, ts)
        return jnp.sum((noisy_values - pred_values)**2)

    #objective_with_grad = jax.value_and_grad(objective)

    objective = pypesto.Objective(
        fun=objective,
        grad=jax.grad(objective)
    )

    problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
    default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
    optimizer = pypesto.optimize.SacessOptimizer(ess_init_args = default_ess_options, max_walltime_s=600)
    result_custom_problem = optimizer.minimize(problem=problem1)

if __name__ == '__main__':
    main()

Alternatively, SacessOptimizer(..., mp_start_method="fork") might solve this specific issue, but might introduce other problems.

MaAl13 commented 1 month ago

Hi @dweindl, I tried running your code but still get the following error:

Traceback (most recent call last):
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 60, in <module>
    main()
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 57, in main
    result_custom_problem = optimizer.minimize(problem=problem1)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/site-packages/pypesto/optimize/ess/sacess.py", line 246, in minimize
    p.start()
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/opt/anaconda3/envs/IVF_env/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'main.<locals>.objective'

dweindl commented 1 month ago

That's now the problem @FFroehlich was referring to. I am not sufficiently familiar with Jax to help there.

FFroehlich commented 1 month ago

MaAl13 wrote: "It now runs through, but fails badly for the line search and doesn't get away from the initial guesses"

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00, 8.25it/s]

Optimization Result

  • number of starts: 20
  • best value: 9694.908203125, id=6
  • worst value: inf, id=19
  • number of non-finite values: 19
  • execution time summary:
    • Mean execution time: 0.121s
    • Maximum execution time: 1.605s, id=0
    • Minimum execution time: 0.012s, id=6
  • summary of optimizer messages:

    Count Message

    20 ABNORMAL_TERMINATION_IN_LNSRCH

  • best value found (approximately) 1 time(s)
  • number of plateaus found: 0

A summary of the best run:

Optimizer Result

  • optimizer used: <ScipyOptimizer method=L-BFGS-B options={'disp': False, 'maxfun': 1000}>
  • message: ABNORMAL_TERMINATION_IN_LNSRCH
  • number of evaluations: 2
  • time taken to optimize: 0.012s
  • startpoint: [2.46408138 2.725371 6.14559588 1.76943688]
  • endpoint: [2.46408138 2.725371 6.14559588 1.76943688]
  • final objective value: 9694.908203125
  • final gradient value: [ 8.141042e+07 7.118934e+08 9.747198e+07 -7.434768e+08]

Hard to guess what the issue is here. This looks like incorrect gradients, but it likely goes beyond what we can help with in the context of an issue on GitHub.
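
One way to test the suspicion of incorrect gradients is to compare the supplied gradient with central finite differences at a single point. The sketch below is plain Python with a toy objective; `finite_difference_gradient`, `max_relative_error`, `toy_objective` and `toy_gradient` are hypothetical names that are not part of pyPESTO or JAX, and the step size `eps` is an assumption that would need tuning for a real model. Note also that the objective value reported above, 9694.908203125, is a multiple of 1/1024, which is the spacing of single-precision floats at that magnitude, and JAX computes in 32-bit unless `jax_enable_x64` is set, so limited precision may be one thing such a check brings to light.

```python
def finite_difference_gradient(fun, x, eps=1e-6):
    """Approximate the gradient of fun at x with central differences."""
    grad = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((fun(up) - fun(down)) / (2 * eps))
    return grad


def max_relative_error(grad_a, grad_b, floor=1e-12):
    """Largest component-wise relative difference between two gradients."""
    return max(
        abs(a - b) / max(abs(a), abs(b), floor)
        for a, b in zip(grad_a, grad_b)
    )


# Hypothetical toy objective with a known analytic gradient.
def toy_objective(x):
    return x[0] ** 2 + 3.0 * x[0] * x[1]


def toy_gradient(x):
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]


if __name__ == "__main__":
    point = [1.5, -0.5]
    approx = finite_difference_gradient(toy_objective, point)
    print(max_relative_error(approx, toy_gradient(point)))
```

For the problem in this thread, `fun` and the gradient would be the two callables passed to `pypesto.Objective`, evaluated at one of the reported startpoints, with lists converted to arrays as needed. A large relative error points at the gradient (or at precision), while a small one suggests looking at the objective itself, for example the non-finite values reported for 19 of the 20 starts.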

FFroehlich commented 1 month ago

dweindl wrote: "That's now the problem @FFroehlich was referring to. I am not sufficiently familiar with Jax to help there."

Correct. Equinox provides some guidance on serialisation, but I don't know how complex it is to get that running with multiprocessing in sacess.
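
The pickling failure and one possible way around it can be sketched with the standard library alone. The example assumes that the non-picklable part (a lambda here, standing in for a jitted function) can be rebuilt from plain data in each worker process; `make_local_objective`, `LazyObjective` and `is_picklable` are hypothetical names, and whether sacess and a real JAX/diffrax objective work with this pattern is untested.

```python
import pickle


def make_local_objective(data):
    """Build the objective as a closure, like the one defined inside main()."""
    def objective(x):
        return sum((d - x) ** 2 for d in data)
    return objective


class LazyObjective:
    """Module-level callable that rebuilds its inner function after unpickling."""

    def __init__(self, data):
        self.data = list(data)
        self._fun = None  # stand-in for a jitted function, built on first call

    def _build(self):
        data = self.data
        # A lambda cannot be pickled, much like a function local to main().
        return lambda x: sum((d - x) ** 2 for d in data)

    def __call__(self, x):
        if self._fun is None:
            self._fun = self._build()
        return self._fun(x)

    def __getstate__(self):
        # Ship only the plain data to other processes; drop the cached function.
        state = self.__dict__.copy()
        state["_fun"] = None
        return state


def is_picklable(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


if __name__ == "__main__":
    print(is_picklable(make_local_objective([1.0, 2.0])))
    print(is_picklable(LazyObjective([1.0, 2.0])))
```

The design choice is to pickle only the data and to rebuild the function on first use in each process, so nothing defined inside another function ever has to cross a process boundary.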

MaAl13 commented 1 month ago

Okay, so you would not recommend using scatter search for parameter estimation of ODEs? Is there anything else in pypesto that you can recommend that is compatible with jax and has nice global properties? Since it is easy to get the gradient with diffrax, I think it would be a shame not to use it.

FFroehlich commented 1 month ago

MaAl13 wrote: "Okay, so you would not recommend using scatter search for parameter estimation of ODEs? Is there anything else in pypesto that you can recommend that is compatible with jax and has nice global properties? Since it is easy to get the gradient with diffrax, I think it would be a shame not to use it."

I wouldn’t go that far. In both cases, the issues you encountered should be salvageable. However, since they aren’t ‘bugs’ per se and will require some effort to resolve, providing ready-made solutions is beyond the support we can offer. That said, we always welcome contributions and are happy to provide guidance.