MaAl13 opened this issue 4 months ago
I think you will need to be a bit more specific about what you mean by "not working".
I get the following error running the above code:
```
Traceback (most recent call last):
  File "
        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
Traceback (most recent call last):
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 56, in
```
Do you get the same error with standard optimisers in pypesto? My guess is that this is about serialisation of the model, where `sacess` relies on `pickle`/`deepcopy`, which doesn't play nicely with JAX.
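One quick way to test this guess is to check whether an object survives a `pickle` round trip and a `copy.deepcopy`, since the `spawn` start method has to pickle whatever it sends to worker processes. The helper below is a minimal sketch; `check_serialisable` is a hypothetical name, not part of pypesto.

```python
import copy
import pickle


def check_serialisable(obj):
    """Try pickling and deep-copying ``obj``.

    Returns a dict mapping "pickle" and "deepcopy" to ``None`` on success,
    or to the ``repr`` of the exception that was raised.
    """
    results = {}
    try:
        # Round trip through pickle, as multiprocessing does with "spawn".
        pickle.loads(pickle.dumps(obj))
        results["pickle"] = None
    except Exception as err:  # PicklingError, AttributeError, TypeError, ...
        results["pickle"] = repr(err)
    try:
        copy.deepcopy(obj)
        results["deepcopy"] = None
    except Exception as err:
        results["deepcopy"] = repr(err)
    return results


if __name__ == "__main__":
    # A lambda cannot be pickled by the standard pickle module
    # (deepcopy treats functions as atomic, so that part succeeds),
    # whereas a built-in such as abs passes both checks.
    print(check_serialisable(lambda x: x))
    print(check_serialisable(abs))
```

In the setting of this issue, one would call it on the `pypesto.Objective` and on `problem1`; whichever of them fails narrows down what needs to change.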
Hi @MaAl13, when using multiprocessing (directly, or indirectly as here through `SacessOptimizer`), always protect your module-level code with `if __name__ == '__main__':`, as suggested in the error message above.
I.e., in your case:
```python
import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

def main():
    y0, ts, noisy_values = get_data()

    # Define objective function
    @jax.jit
    def objective(parameters):
        pred_values = solve(parameters, y0, ts)
        return jnp.sum((noisy_values - pred_values)**2)

    #objective_with_grad = jax.value_and_grad(objective)
    objective = pypesto.Objective(
        fun=objective,
        grad=jax.grad(objective)
    )

    problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
    default_ess_options = pypesto.optimize.get_default_ess_options(8, 4, local_optimizer=ScipyOptimizer(method='trust-constr'))
    optimizer = pypesto.optimize.SacessOptimizer(ess_init_args=default_ess_options, max_walltime_s=600)
    result_custom_problem = optimizer.minimize(problem=problem1)

if __name__ == '__main__':
    main()
```
Alternatively, `SacessOptimizer(..., mp_start_method="fork")` might solve this specific issue, but it might introduce other problems.
Hi @dweindl, I tried running your code but still get the following error:
```
Traceback (most recent call last):
  File "/Users/malmansto/Documents/IVF_hormone_sim/code/test_ESS.py", line 60, in
```
That's now the problem @FFroehlich was referring to. I am not sufficiently familiar with Jax to help there.
Thanks for getting back so fast! :)
```python
import numpy as np
import jax
import time
import jax.numpy as jnp
import diffrax as dfx
import equinox as eqx
import pypesto
import pypesto.optimize as optimize
from pypesto.optimize import ScipyOptimizer
import multiprocessing

# Lotka-Volterra model
def vector_field(t, y, args):
    prey, predator = y
    α, β, γ, δ = args
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    return jnp.stack([d_prey, d_predator])

def solve(parameters, y0, ts):
    term = dfx.ODETerm(vector_field)
    solver = dfx.Tsit5()
    saveat = dfx.SaveAt(ts=ts)
    sol = dfx.diffeqsolve(
        term, solver, t0=ts[0], t1=ts[-1], dt0=0.1, y0=y0, args=parameters, saveat=saveat,
        adjoint=dfx.RecursiveCheckpointAdjoint(),
    )
    return sol.ys

# Generate synthetic data
def get_data():
    y0 = jnp.array([9.0, 9.0])
    true_parameters = jnp.array([0.1, 0.02, 0.4, 0.02])
    ts = jnp.linspace(0, 30, 20)
    values = solve(true_parameters, y0, ts)
    return y0, ts, values + 0.1 * jax.random.normal(jax.random.PRNGKey(0), values.shape)

y0, ts, noisy_values = get_data()

# Define objective function
@jax.jit
def objective(parameters):
    pred_values = solve(parameters, y0, ts)
    return jnp.sum((noisy_values - pred_values)**2)

#objective_with_grad = jax.value_and_grad(objective)
objective = pypesto.Objective(
    fun=objective,
    grad=jax.grad(objective)
)

problem1 = pypesto.Problem(objective=objective, lb=np.zeros((4, 1)), ub=np.ones((4, 1))*10)
optimizer = optimize.ScipyOptimizer()
engine = pypesto.engine.SingleCoreEngine()
n_starts = 20
result = optimize.minimize(
    problem=problem1, optimizer=optimizer, n_starts=n_starts, engine=engine
)
print(result.summary())
```
It now runs through, but the line search fails badly and the optimizer doesn't get away from the initial guesses:
```
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00, 8.25it/s]
Optimization Result
- number of starts: 20
- best value: 9694.908203125, id=6
- worst value: inf, id=19
- number of non-finite values: 19
- execution time summary:
    - Mean execution time: 0.121s
    - Maximum execution time: 1.605s, id=0
    - Minimum execution time: 0.012s, id=6
- summary of optimizer messages:

  | Count | Message                        |
  |------:|:-------------------------------|
  |    20 | ABNORMAL_TERMINATION_IN_LNSRCH |

- best value found (approximately) 1 time(s)
- number of plateaus found: 0

A summary of the best run:

Optimizer Result
- optimizer used: <ScipyOptimizer method=L-BFGS-B options={'disp': False, 'maxfun': 1000}>
- message: ABNORMAL_TERMINATION_IN_LNSRCH
- number of evaluations: 2
- time taken to optimize: 0.012s
- startpoint: [2.46408138 2.725371 6.14559588 1.76943688]
- endpoint: [2.46408138 2.725371 6.14559588 1.76943688]
- final objective value: 9694.908203125
- final gradient value: [ 8.141042e+07 7.118934e+08 9.747198e+07 -7.434768e+08]
```
Hard to guess what the issue is here. This looks like incorrect gradients, but it likely goes beyond what we can help with in the context of an issue on GitHub.
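A standard way to test the suspicion of incorrect gradients is to compare the supplied gradient with central finite differences at a few points. The sketch below does this with NumPy only; `finite_difference_gradient` and `gradient_check_error` are hypothetical helpers, not pypesto functions. One caveat for the JAX objective in this thread: JAX computes in single precision unless `jax.config.update("jax_enable_x64", True)` is set, and finite differences with small steps are much less reliable in float32.

```python
import numpy as np


def finite_difference_gradient(fun, x, eps=1e-6):
    """Central finite-difference approximation of the gradient of ``fun`` at ``x``."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (fun(x + step) - fun(x - step)) / (2 * eps)
    return grad


def gradient_check_error(fun, grad, x, eps=1e-6):
    """Largest deviation between ``grad(x)`` and finite differences.

    Relative error for gradient entries with magnitude above 1,
    absolute error for smaller entries (avoids dividing by ~0).
    """
    g_supplied = np.asarray(grad(x), dtype=float)
    g_fd = finite_difference_gradient(fun, x, eps)
    scale = np.maximum(np.abs(g_fd), 1.0)
    return float(np.max(np.abs(g_supplied - g_fd) / scale))


if __name__ == "__main__":
    w = np.array([1.0, 2.0, 3.0])
    f = lambda x: float(np.sum(w * x ** 2))
    x0 = np.array([1.0, -2.0, 0.5])
    print(gradient_check_error(f, lambda x: 2 * w * x, x0))  # correct gradient: tiny
    print(gradient_check_error(f, lambda x: w * x, x0))      # missing factor 2: large
```

For the issue's code, one would pass the jitted `objective` and `jax.grad(objective)` and evaluate at one of the reported startpoints; a large error there would support the incorrect-gradient explanation.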
Correct, equinox provides some guidance on serialisation, but I don't know how complex it is to get that running with multiprocessing in `sacess`.
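One common way around such serialisation problems (a general Python pattern, not something pypesto or equinox prescribes) is to make the objective a small class that only stores plain NumPy data and rebuilds its compiled function lazily in each process, dropping the cached function in `__getstate__`. The sketch below uses a NumPy stand-in at the point where real code would build and `jax.jit` the objective; `LazyObjective` and `_build` are hypothetical names.

```python
import pickle

import numpy as np


class LazyObjective:
    """Picklable least-squares objective that rebuilds its compiled function on demand."""

    def __init__(self, data):
        self.data = np.asarray(data, dtype=float)  # plain NumPy, pickles fine
        self._compiled = None  # per-process cache, never serialised

    def _build(self):
        # Placeholder for the expensive, possibly non-picklable part.
        # With JAX, this is where the objective would be constructed and jitted.
        data = self.data
        return lambda x: float(np.sum((data - np.asarray(x)) ** 2))

    def __call__(self, x):
        if self._compiled is None:
            self._compiled = self._build()
        return self._compiled(x)

    def __getstate__(self):
        # Copy the instance dict and drop the cached function so that
        # pickle (and deepcopy, which also uses __getstate__) only sees data.
        state = self.__dict__.copy()
        state["_compiled"] = None
        return state


if __name__ == "__main__":
    obj = LazyObjective([1.0, 2.0, 3.0])
    print(obj([0.0, 0.0, 0.0]))  # builds the cache in this process
    clone = pickle.loads(pickle.dumps(obj))  # works despite the cached lambda
    print(clone([1.0, 2.0, 3.0]))  # rebuilds the cache after unpickling
```

The design choice is that each worker pays the compilation cost once, in exchange for never having to serialise compiled JAX objects.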
Okay, so you guys would not recommend using scatter search for parameter estimation of ODEs then? Is there anything else in pypesto that you can recommend that is compatible with JAX and has nice global properties? Since gradients are easy to get with diffrax, I think it would be a shame not to use them.
I wouldn’t go that far. In both cases, the issues you encountered should be salvageable. However, since they aren’t ‘bugs’ per se and will require some effort to resolve, providing ready-made solutions is beyond the support we can offer. That said, we always welcome contributions and are happy to provide guidance.
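For reference, the multistart local optimisation that `optimize.minimize(..., n_starts=20)` performed earlier in this thread works roughly like the sketch below: draw startpoints uniformly within the bounds, run a gradient-based local optimiser (here SciPy's L-BFGS-B) from each, and rank the results. This is a minimal illustration using SciPy directly; `multistart` is a hypothetical helper, not pypesto's implementation.

```python
import numpy as np
from scipy.optimize import minimize, rosen, rosen_der


def multistart(fun, grad, lb, ub, n_starts=20, seed=0):
    """Run L-BFGS-B from uniformly sampled startpoints; return results, best first."""
    rng = np.random.default_rng(seed)
    lb = np.asarray(lb, dtype=float)
    ub = np.asarray(ub, dtype=float)
    bounds = list(zip(lb, ub))
    results = []
    for _ in range(n_starts):
        x0 = rng.uniform(lb, ub)  # uniform startpoint within the bounds
        res = minimize(fun, x0, jac=grad, method="L-BFGS-B", bounds=bounds)
        results.append(res)
    # Sort by objective value; non-finite values are placed last.
    return sorted(results, key=lambda r: r.fun if np.isfinite(r.fun) else np.inf)


if __name__ == "__main__":
    # Usage on the Rosenbrock test function shipped with SciPy.
    best = multistart(rosen, rosen_der, lb=[-2.0, -2.0], ub=[2.0, 2.0], n_starts=10)[0]
    print(best.x, best.fun)
```

Because every local run consumes the supplied gradient, this approach benefits directly from diffrax/JAX gradients, while the spread of startpoints provides some global exploration.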
Hello, I want to use your package to do parameter estimation of ODEs and later compute confidence intervals with profile likelihood. However, the following code is not working for me on a toy example. Can you tell me what I am doing wrong? I want to use scatter search since it has been shown to have better convergence properties than purely local or global methods.