Open MaAl13 opened 1 year ago
If I had to guess what's going on here: the parameters are being suggested with the wrong sign, so that the Lotka-Volterra equations blow up in finite time.
I've not tried running your example as it is a little large. See if you can try reducing it to an MWE, in particular without numpyro. You should be able to simply check what parameters numpyro is suggesting, and then evaluate Diffrax on those parameters directly.
General debugging tips for this kind of thing by the way:
- Use `jax.debug.{print, breakpoint}` inside your vector field, to see what times and values it is being evaluated at. Do the times converge towards a single time (i.e. the point at which an equation blows up in finite time)? Do the values explode towards very large numbers, very small numbers, or `inf`, or `-inf`, or `nan`?
- Use `diffeqsolve(..., throw=False, saveat=SaveAt(steps=True))` and see what times it is being evaluated at. You will get the times and values at which the solver placed steps. These are stored in arrays of length `max_steps`: the first part of this array will be the times you evaluate at; after that it will be padded with `inf`.
Hello, I want to use Diffrax for Bayesian inference of parameters in numpyro. However, as soon as I change the stepsize controller from ConstantStepSize to PIDController I get an error. Increasing max_steps to very large values and also using an implicit solver doesn't help. Can you tell me what the problem is? The code is the following
The error is
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py:150: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using numpyro.set_host_device_count(4) at the beginning of your program. You can double-check how many devices are available in your system using jax.local_device_count().
  inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample(), dense_mass=True), **mcmc_kwargs)
warmup: 0%| | 1/4000 [00:19<21:38:58, 19.49s/it, 1 steps of size 2.34e+00. acc. prob=0.00]
Traceback (most recent call last):
  File "/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py", line 151, in <module>
    inference_mcmc.run(seed1, **data_dict)
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 598, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 160, in _laxmap
    ys.append(f(x))
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 404, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py", line 358, in fori_collect
    vals = jit(_body_fn)(i, vals)
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 565, in cache_miss
    out_flat = call_bind_continuation(execute(args_flat))
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 2113, in call
    out_bufs = self.xla_executable.execute_sharded_on_local_devices(
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.

At:
  /home/malmansto/anaconda3/lib/python3.9/site-packages/equinox/internal/errors.py(17): raises
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(142): _flat_callback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(42): pure_callback_impl
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(105): _callback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/mlir.py(1798): _wrapped_callback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py(2113): call
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py(565): cache_miss
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py(358): fori_collect
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(404): _single_chain_mcmc
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(160): _laxmap
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(598): run
  /home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py(151): <module>

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py", line 151, in <module>
    inference_mcmc.run(seed1, **data_dict)
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 598, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 160, in _laxmap
    ys.append(f(x))
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 404, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py", line 358, in fori_collect
    vals = jit(_body_fn)(i, vals)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.

At:
  /home/malmansto/anaconda3/lib/python3.9/site-packages/equinox/internal/errors.py(17): raises
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(142): _flat_callback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(42): pure_callback_impl
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(105): _callback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/mlir.py(1798): _wrapped_callback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py(2113): call
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py(565): cache_miss
  /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py(358): fori_collect
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(404): _single_chain_mcmc
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(160): _laxmap
  /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(598): run
  /home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py(151): <module>