patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.44k stars 129 forks source link

Problem when using Diffrax for numpyro #218

Open MaAl13 opened 1 year ago

MaAl13 commented 1 year ago

Hello , i want to use Diffrax for Bayesian inference of parameters in numpyro. However, as soon as i change the StepsizeControler from ConstantStepsie to DPIController i get an error. Changing the max_steps to really big numbers and also using an implicit solver doesn't help. Can you maybe tell me what the problem is? The code is the following

import sys
# from pathlib import Path
from jax.experimental.ode import odeint
# import arviz as az
import dill
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from diffrax import PIDController, Dopri5, ODETerm, SaveAt, diffeqsolve, Kvaerno5
from jax import random
from numpyro.infer import MCMC, NUTS  # , Predictive
from sklearn.preprocessing import LabelEncoder

# from fem_cycle_model.fetch_params import fetch_params
# from fem_cycle_model.main import run_model
numpyro.enable_x64()
pd.options.mode.chained_assignment = None  # default='warn'
assert numpyro.__version__.startswith("0.11.0")

# Select the number of cores that numpyro will use
numpyro.set_host_device_count(1)

def test_ode(t, z, theta):
    """
    Lotka–Volterra equations. Real positive parameters `alpha`, `beta`, `gamma`, `delta`
    describes the interaction of two species.
    """
    u = z[0]
    v = z[1]
    alpha, beta, gamma, delta = (
        theta[0],
        theta[1],
        theta[2],
        theta[3],
    )
    #print(theta)
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])

# Create a linspace of time points 0, 0.2, 0.4, ...

def sim(params):

    term = ODETerm(test_ode)
    solver = Dopri5()
    saveat = SaveAt(ts=np.linspace(0, 10, 50))
    stepsize_controller = PIDController(rtol = 1e-3, atol = 1e-6, step_ts = np.linspace(0, 10, 50))
    sol = diffeqsolve(
        term,
        solver,
        t0=0,
        t1=10,
        dt0=0.1,
        y0=jnp.array([1, 0.2]),
        saveat=saveat,
        max_steps = 20000,
        stepsize_controller=stepsize_controller,
        args=params,
    )

    #sol = odeint(test_ode, jnp.array([1,0.2]), jnp.linspace(0, 10, 50), jnp.array(params), rtol=1e-6, atol=1e-5, mxstep=1000)
    return sol.ys[:, 0], sol.ys[:, 1]

def run_model_data(num_patients, params):
    y1_total = jnp.array([])
    y2_total = jnp.array([])
    for i in range(num_patients):
        params_i = jnp.array(params) * (i + 1)
        y1, y2 = sim(params_i)
        y1_total = jnp.append(y1_total, y1)
        y2_total = jnp.append(y2_total, y2)
    return y1_total, y2_total

patient_y1, patient_y2 = run_model_data(1, [2, 3, 4, 5])

def run_model_all(params):
    y1, y2 = sim(params)
    return y1, y2

def model(y1=None, y2=None):
    """
    μ_foll_alpha = numpyro.sample("μ_foll_alpha", dist.Uniform(1.0, 7.0))
    σ_foll_alpha = numpyro.sample("σ_foll_alpha", dist.LeftTruncatedDistribution(dist.Normal(3.0, 2.0)))

    μ_foll_beta = numpyro.sample("μ_foll_beta", dist.Uniform(2.0, 10.0))
    σ_foll_beta = numpyro.sample("σ_foll_beta", dist.LeftTruncatedDistribution(dist.Normal(4.0, 2.0)))

    μ_foll_gamma = numpyro.sample("μ_foll_gamma", dist.Uniform(3.0, 13.0))
    σ_foll_gamma = numpyro.sample("σ_foll_gamma", dist.LeftTruncatedDistribution(dist.Normal(5.0, 2.0)))

    μ_foll_delta = numpyro.sample("μ_foll_delta", dist.Uniform(4.0, 16.0))
    σ_foll_delta = numpyro.sample("σ_foll_delta", dist.LeftTruncatedDistribution(dist.Normal(7.0, 2.0)))

    σ = numpyro.sample("σ", dist.LeftTruncatedDistribution(dist.Normal(0.3, .5)))

    #with numpyro.plate("plate_i", N_PATIENTS):
    foll_alpha = numpyro.sample("foll_alpha", dist.Normal(μ_foll_alpha, σ_foll_alpha))
    foll_beta = numpyro.sample("foll_beta", dist.Normal(μ_foll_beta, σ_foll_beta))
    foll_gamma = numpyro.sample("foll_gamma", dist.Normal(μ_foll_gamma, σ_foll_gamma))
    foll_delta = numpyro.sample("foll_delta", dist.Normal(μ_foll_delta, σ_foll_delta))
    """

    σ = numpyro.sample("σ", dist.LeftTruncatedDistribution(dist.Normal(0.3, .5)))
    foll_alpha = numpyro.sample("foll_alpha", dist.LeftTruncatedDistribution(dist.Normal(2, 1)))
    foll_beta = numpyro.sample("foll_beta", dist.LeftTruncatedDistribution(dist.Normal(3, 1)))
    foll_gamma = numpyro.sample("foll_gamma", dist.LeftTruncatedDistribution(dist.Normal(4, 1)))
    foll_delta = numpyro.sample("foll_delta", dist.LeftTruncatedDistribution(dist.Normal(5, 1)))

    y1_est, y2_est= run_model_all(
        [
            foll_alpha,
            foll_beta,
            foll_gamma,
            foll_delta,
        ]
    )
    with numpyro.plate("likelihood", len(y1)):
        numpyro.sample("obs_dom", dist.Normal(y1_est, σ), obs=y1)
        numpyro.sample("obs_non_dom", dist.Normal(y2_est, σ), obs=y2)

data_dict = dict(
    y1=patient_y1,
    y2 = patient_y2,
)

# Specify the number of chains in the Markov Chain Monte Carlo. Typically set to the nmber of cores in the computer
mcmc_kwargs = dict(num_samples=2000, num_warmup=2000, num_chains=4)

# Select a random key and split it into different parts. This guarantees that we get the same result each time and
# will lead to reproducable results. For more see:
# https://ericmjl.github.io/dl-workshop/02-jax-idioms/03-deterministic-randomness.html
rng_key = random.PRNGKey(12)
seed1, seed2, seed3, seed4, seed5 = random.split(rng_key, 5)

inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample(), dense_mass=True), **mcmc_kwargs)
inference_mcmc.run(seed1, **data_dict)
print(inference_mcmc.print_summary())

The error is

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) /home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py:150: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using numpyro.set_host_device_count(4) at the beginning of your program. You can double-check how many devices are available in your system using jax.local_device_count(). inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample(), dense_mass=True), mcmc_kwargs) warmup: 0%| | 1/4000 [00:19<21:38:58, 19.49s/it, 1 steps of size 2.34e+00. acc. prob=0.00] Traceback (most recent call last): File "/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py", line 151, in inference_mcmc.run(seed1, data_dict) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 598, in run states, last_state = _laxmap(partial_map_fn, map_args) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 160, in _laxmap ys.append(f(x)) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 404, in _single_chain_mcmc collect_vals = fori_collect( File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py", line 358, in fori_collect vals = jit(_body_fn)(i, vals) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, kwargs) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 565, in cache_miss out_flat = call_bind_continuation(execute(args_flat)) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(args, kwargs) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 2113, in call out_bufs = self.xla_executable.execute_sharded_on_local_devices( jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.

At: /home/malmansto/anaconda3/lib/python3.9/site-packages/equinox/internal/errors.py(17): raises /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(142): _flat_callback /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(42): pure_callback_impl /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(105): _callback /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/mlir.py(1798): _wrapped_callback /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py(2113): call /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py(565): cache_miss /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py(358): fori_collect /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(404): _single_chain_mcmc /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(160): _laxmap /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(598): run /home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py(151):

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last): File "/home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py", line 151, in inference_mcmc.run(seed1, **data_dict) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 598, in run states, last_state = _laxmap(partial_map_fn, map_args) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 160, in _laxmap ys.append(f(x)) File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py", line 404, in _single_chain_mcmc collect_vals = fori_collect( File "/home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py", line 358, in fori_collect vals = jit(_body_fn)(i, vals) jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.

At: /home/malmansto/anaconda3/lib/python3.9/site-packages/equinox/internal/errors.py(17): raises /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(142): _flat_callback /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(42): pure_callback_impl /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/callback.py(105): _callback /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/mlir.py(1798): _wrapped_callback /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py(2113): call /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/api.py(565): cache_miss /home/malmansto/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py(162): reraise_with_filtered_traceback /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/util.py(358): fori_collect /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(404): _single_chain_mcmc /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(160): _laxmap /home/malmansto/anaconda3/lib/python3.9/site-packages/numpyro/infer/mcmc.py(598): run /home/malmansto/IVF/Test_Mini_Hierarchichal_Model_2.py(151):

patrick-kidger commented 1 year ago

If I had to guess what's going on here: the parameters are being suggested with the wrong sign, so that the Lotka-Volterra equations blow up in finite time.

I've not tried running you example as it is a little large. See if you can try reducing it to a MWE, in particular without numpyro. You should be able to simply check what parameters numpyro is suggesting, and then evaluate Diffrax on those parameters directly.

General debugging tips for this kind of thing by the way: