pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

BUG: np.linalg.LinAlgError using multivariate_normal kernel of SMC #6786

Open omsai opened 1 year ago

omsai commented 1 year ago

Describe the issue:

Close to the end stages of fitting an ODE model with ABC-SMC, the kernel fails with np.linalg.LinAlgError for a particular model. The exception is raised in the SMC kernel's tune() step rather than in the simulator, so it cannot be caught in the model simulation code, which makes me wonder whether some checks are needed in the kernel. The error doesn't typically happen, but this set of parameters tends to trigger the crash (edit: made the crash reproducible every time instead of every other time using random_seed=456).
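
One plausible route to this error (an assumption, not confirmed in this thread): late in SMC the particles can collapse so that one parameter, for example the discrete `N`, takes the same value in every particle. The sample covariance then has a zero row and column, and `scipy.stats.multivariate_normal` rejects it unless `allow_singular=True`. A minimal sketch with synthetic particles; the gamma-distributed columns and the constant `900.0` are made up, and `make_proposal` is a hypothetical helper that fits a mean and covariance to the particle cloud:

```python
import numpy as np
from scipy.stats import multivariate_normal


def make_proposal(particles, allow_singular=False):
    """Hypothetical helper: fit a multivariate normal to the particle cloud."""
    mean = particles.mean(axis=0)
    cov = np.atleast_2d(np.cov(particles, rowvar=False))
    return multivariate_normal(mean, cov, allow_singular=allow_singular)


rng = np.random.default_rng(0)
n_particles = 200
particles = np.column_stack([
    rng.gamma(2.0, 1.0, size=n_particles),  # stand-in for s
    rng.gamma(2.0, 1.0, size=n_particles),  # stand-in for mu_V
    np.full(n_particles, 900.0),            # collapsed N: identical in every particle
])

try:
    make_proposal(particles)
except np.linalg.LinAlgError as err:
    print("LinAlgError:", err)

# Allowing a singular covariance avoids the error, at the cost of a
# degenerate proposal that effectively cannot move the collapsed dimension.
proposal = make_proposal(particles, allow_singular=True)
print(proposal.rvs(random_state=1).shape)  # (3,)
```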

You will need this small 2 KB CSV data file to run the reproducible code example: pantaleo1995-figure1.csv

Reproducible code example:

"""
Perelson AS, Kirschner DE, De Boer R. Dynamics of HIV infection of CD4+ T
cells. Math Biosci. 1993 Mar;114(1):81-125.
doi: 10.1016/0025-5564(93)90043-a. PMID: 8096155.
"""

import pickle

import numpy as np
import pandas as pd
import pymc as pm
from scipy.integrate import odeint

T_STEADY_STATE = 5  # years, Assumed time to CD4+ steady state.

def read_obs_hiv_cd4_timeseries():
    """Read the patient cd4 time series spreadsheet."""
    df = pd.read_csv("pantaleo1995-figure1.csv")
    df = df.rename({"cd4_cells_per_mm3": "cd4"}, axis='columns')
    df.year += T_STEADY_STATE
    df["day"] = df.year * 365
    # Convert subject strings to float for PyMC's Distribution class that has a
    # convert_observed_data() function that assumes all data should be float.
    df["group"] = df["group"].str.extract(r"subject(\d+)")
    return df

# Sorting by time is necessary for merging columns later on.
DF_OBSERVED = read_obs_hiv_cd4_timeseries().sort_values("day")
# Collapse data into single representative sample.
DF_OBSERVED_1 = DF_OBSERVED[DF_OBSERVED["group"] == "1"]
OBSERVED = DF_OBSERVED_1.cd4.to_numpy()
# Only solve the ODE system at timepoints for which we have data.  T_VALS is in
# days because all the parameter units are in days.
T_VALS = DF_OBSERVED_1.day.to_numpy()
# Initial values.
Y0 = [1e3, 0, 0, 1e-3]

def hiv(y, t, s, mu_V, N):
    """Rates of change RHS (equations 5a-5d, page 87)."""
    # mm^{-3}, Maximum CD4+ cells.
    T_max = 1500
    # day^{-1}, Death rate of uninfected and latently CD4+ cells.
    mu_T = 0.02
    # day^{-1}, Death rate of actively infected CD4+ cells.
    mu_b = 0.24
    # day^{-1}, Rate of growth for the CD4+ cells.
    r = 0.03
    # mm^{3}day^{-1}, Rate constant for CD4+ becoming infected.
    k_1 = 2.4e-5
    # day^{-1}, Rate latently to actively infected conversion.
    k_2 = 3e-3
    return np.array([
        # T
        s - mu_T*y[0] + r*y[0]*(1 - (y[0] + y[1] + y[2])/T_max) -
        k_1*y[3]*y[0],
        # T_li
        k_1*y[3]*y[0] - mu_T*y[1] - k_2*y[1],
        # T_ai
        k_2*y[1] - mu_b*y[2],
        # V
        N*mu_b*y[2] - k_1*y[3]*y[0] - mu_V*y[3],
    ])

def simulate_hiv(rng, s, mu_V, N, size=None):
    # Pad zero value time to match the Y0 initial condition.
    times = np.concatenate((np.zeros((1,)), T_VALS))
    try:
        ret = odeint(hiv, Y0, times, rtol=0.01, mxstep=100,
                     args=(s, mu_V, N))
        # Apply the dimensional reduction here to make the model comparable to
        # the data.  Sum all the T cell counts.
        return np.sum(ret[1:, 0:3], axis=1)
    except np.linalg.LinAlgError:
        # ndarray.fill() works in place and returns None, so build the NaN
        # array directly.
        return np.full(len(T_VALS), np.nan)

if __name__ == "__main__":
    # Simulate from the model.
    with pm.Model() as model_hiv:
        # Table 1, page 88.
        # day^{-1}mm^{-3}, Rate of supply of CD4+ cells from precursors.
        s = pm.Gamma("s", alpha=1.985656, beta=5.681687)
        # day^{-1}, Death rate of free virus.
        mu_V = pm.Gamma("mu_V", alpha=1.985657, beta=1.363605)
        # Number of free virus produced by lysing a CD4+ cell.
        N = pm.NegativeBinomial("N", n=13.5, p=0.01477833)

        # Instead of specifying a likelihood function, simulate from the
        # model.
        sim = pm.Simulator("sim",
                           simulate_hiv,
                           params=(s, mu_V, N),
                           epsilon=10,
                           observed=OBSERVED)
        # Collect inference data.
        idata_hiv = pm.sample_smc(cores=8, random_seed=456)

    with open("perelson1993-01_output-idata_hiv.pkl", "wb") as file_:
        pickle.dump(idata_hiv, file_)
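
For reference, the right-hand side returned by `hiv()` can be written out as below, with y = (T, T_li, T_ai, V) and the same symbols as the code. This is transcribed from the function above, not re-checked against equations 5a-5d in the paper. `simulate_hiv()` returns T + T_li + T_ai at each observed day, which `pm.Simulator` compares with the observed CD4+ counts.

```latex
\begin{aligned}
\frac{dT}{dt}      &= s - \mu_T T + r T\left(1 - \frac{T + T_{li} + T_{ai}}{T_{\max}}\right) - k_1 V T \\
\frac{dT_{li}}{dt} &= k_1 V T - \mu_T T_{li} - k_2 T_{li} \\
\frac{dT_{ai}}{dt} &= k_2 T_{li} - \mu_b T_{ai} \\
\frac{dV}{dt}      &= N \mu_b T_{ai} - k_1 V T - \mu_V V
\end{aligned}
```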

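A pitfall worth knowing for the NaN fallback in `simulate_hiv()`: `ndarray.fill()` modifies the array in place and returns `None`, so an expression like `np.empty((n,)).fill(np.nan)` evaluates to `None`, whereas `np.full(n, np.nan)` returns the NaN-filled array. A minimal check, where `n = 5` is an arbitrary length standing in for `len(T_VALS)`:

```python
import numpy as np

n = 5  # arbitrary length standing in for len(T_VALS)

# ndarray.fill() mutates the array in place and returns None.
broken = np.empty((n,)).fill(np.nan)

# np.full() allocates and fills in one step, returning the array.
fixed = np.full(n, np.nan)

print(broken)       # None
print(fixed.shape)  # (5,)
```
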
Error message:

<details>

$ python3 perelson1993pymc_scipy_01_model.py
Initializing SMC sampler...
Sampling 8 chains in 8 jobs
...
83.00% [83/100 00:00<?  Stage: 8 Beta: 0.834]
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/sampling.py", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/sampling.py", line 359, in _sample_smc_int
    smc.tune()
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/kernels.py", line 389, in tune
    self.proposal_dist = multivariate_normal(mean, cov)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 393, in __call__
    return multivariate_normal_frozen(mean, cov,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 834, in __init__
    self._dist._process_parameters(mean, cov, allow_singular))
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 417, in _process_parameters
    psd = _PSD(cov, allow_singular=allow_singular)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 172, in __init__
    raise np.linalg.LinAlgError(msg)
numpy.linalg.LinAlgError: When `allow_singular is False`, the input matrix must be symmetric positive definite.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/pnanda/immunology/papers/calipro-vs-abc/code/perelson1993pymc_scipy_01_model.py", line 101, in <module>
    idata_hiv = pm.sample_smc(cores=8)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/sampling.py", line 213, in sample_smc
    results = run_chains_parallel(
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/sampling.py", line 388, in run_chains_parallel
    results = _starmap_with_kwargs(
              ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/sampling.py", line 415, in _starmap_with_kwargs
    return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/pool.py", line 375, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/pool.py", line 774, in get
    raise self._value
  File "/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/sampling.py", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/sampling.py", line 359, in _sample_smc_int
    smc.tune()
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/pymc/smc/kernels.py", line 389, in tune
    self.proposal_dist = multivariate_normal(mean, cov)
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 393, in __call__
    return multivariate_normal_frozen(mean, cov,
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 834, in __init__
    self._dist._process_parameters(mean, cov, allow_singular))
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 417, in _process_parameters
    psd = _PSD(cov, allow_singular=allow_singular)
  File "/Users/pnanda/Library/Python/3.11/lib/python/site-packages/scipy/stats/_multivariate.py", line 172, in __init__
    raise np.linalg.LinAlgError(msg)
numpy.linalg.LinAlgError: When `allow_singular is False`, the input matrix must be symmetric positive definite.

</details>

PyMC version information:

Installed from pip on macOS 13.4 (Apple M2 chip):

```console
$ pip list
Package                  Version
------------------------ --------
anyio                    3.6.2
appnope                  0.1.3
argon2-cffi              21.3.0
argon2-cffi-bindings     21.2.0
arrow                    1.2.3
arviz                    0.15.1
asttokens                2.2.1
attrs                    23.1.0
autopep8                 2.0.2
backcall                 0.2.0
beautifulsoup4           4.12.2
black                    23.3.0
bleach                   6.0.0
cachetools               5.3.0
cffi                     1.15.1
click                    8.1.3
cloudpickle              2.2.1
comm                     0.1.3
cons                     0.4.5
contourpy                1.0.7
cycler                   0.11.0
Cython                   3.0.0b3
dask                     2023.4.1
debugpy                  1.6.7
decorator                5.1.1
defusedxml               0.7.1
distributed              2023.4.1
elfi                     0.8.6
entrypoints              0.4
etuples                  0.3.8
executing                1.2.0
fastjsonschema           2.16.3
fastprogress             1.0.3
filelock                 3.12.0
flake8                   6.0.0
fonttools                4.39.3
fqdn                     1.5.1
fsspec                   2023.5.0
gitdb                    4.0.10
GitPython                3.1.31
GPy                      1.10.0
h5netcdf                 1.1.0
h5py                     3.8.0
idna                     3.4
importlib-metadata       6.6.0
ipykernel                6.22.0
ipyparallel              8.6.1
ipython                  8.12.0
ipython-genutils         0.2.0
ipywidgets               8.0.6
isoduration              20.11.0
jabbar                   0.0.15
jedi                     0.18.2
Jinja2                   3.1.2
joblib                   1.2.0
jsonpointer              2.3
jsonschema               4.17.3
jupyter                  1.0.0
jupyter_client           8.2.0
jupyter-console          6.6.3
jupyter_core             5.3.0
jupyter-events           0.6.3
jupyter_server           2.5.0
jupyter_server_terminals 0.4.4
jupyterlab-pygments      0.2.2
jupyterlab-widgets       3.0.7
kiwisolver               1.4.4
locket                   1.0.0
logical-unification      0.4.5
MarkupSafe               2.1.2
matplotlib               3.7.1
matplotlib-inline        0.1.6
mccabe                   0.7.0
miniKanren               1.0.3
mistune                  2.0.5
mizani                   0.9.0
msgpack                  1.0.5
multipledispatch         0.6.0
mypy-extensions          1.0.0
nbclassic                0.5.5
nbclient                 0.7.3
nbconvert                7.3.1
nbformat                 5.8.0
nbstripout               0.6.1
nest-asyncio             1.5.6
networkx                 3.1
notebook                 6.5.4
notebook_shim            0.2.3
numdifftools             0.9.41
numpy                    1.24.3
packaging                23.1
pandas                   2.0.1
pandocfilters            1.5.0
paramz                   0.9.5
parso                    0.8.3
partd                    1.4.0
pathspec                 0.11.1
patsy                    0.5.3
pexpect                  4.8.0
pickleshare              0.7.5
Pillow                   9.5.0
pip                      23.1.2
platformdirs             3.2.0
plotnine                 0.10.1
prometheus-client        0.16.0
prompt-toolkit           3.0.38
psutil                   5.9.5
ptyprocess               0.7.0
pure-eval                0.2.2
pyabc                    0.12.10
pyarrow                  12.0.0
pycodestyle              2.10.0
pycparser                2.21
pyflakes                 3.0.1
Pygments                 2.15.1
pymc                     5.3.0
pyparsing                3.0.9
pyrsistent               0.19.3
pytensor                 2.11.1
python-dateutil          2.8.2
python-json-logger       2.0.7
pytz                     2023.3
PyYAML                   6.0
pyzmq                    25.0.2
qtconsole                5.4.2
QtPy                     2.3.1
redis                    4.5.5
rfc3339-validator        0.1.4
rfc3986-validator        0.1.1
scikit-learn             1.2.2
scikits.odes             2.7.0
scipy                    1.10.1
seaborn                  0.12.2
Send2Trash               1.8.0
setuptools               67.6.1
six                      1.16.0
smmap                    5.0.0
sniffio                  1.3.0
sortedcontainers         2.4.0
soupsieve                2.4.1
SQLAlchemy               2.0.12
stack-data               0.6.2
statsmodels              0.13.5
tblib                    1.7.0
terminado                0.17.1
threadpoolctl            3.1.0
tinycss2                 1.2.1
tomli                    2.0.1
toolz                    0.12.0
tornado                  6.3.1
tqdm                     4.65.0
traitlets                5.9.0
typing_extensions        4.5.0
tzdata                   2023.3
uri-template             1.2.0
urllib3                  2.0.2
wcwidth                  0.2.6
webcolors                1.13
webencodings             0.5.1
websocket-client         1.5.1
wheel                    0.40.0
widgetsnbextension       4.0.7
xarray                   2023.4.2
xarray-einstats          0.5.1
yapf                     0.33.0
zict                     3.0.0
zipp                     3.15.0
```

Context for the issue:

Causes a crash with these particular model parameters.
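
On the question of whether the kernel needs a check: one common generic guard is to retry building the multivariate normal with a small, growing diagonal jitter added to the covariance. This is only a sketch; PyMC's kernel code is not shown in this thread, and `robust_mvn`, `rel_jitter` and `max_tries` are hypothetical names:

```python
import numpy as np
from scipy.stats import multivariate_normal


def robust_mvn(mean, cov, rel_jitter=1e-10, max_tries=8):
    """Hypothetical guard: build a frozen multivariate normal, adding a
    growing diagonal jitter whenever SciPy rejects the covariance as not
    positive definite. Returns (distribution, jitter_used)."""
    cov = np.atleast_2d(np.asarray(cov, dtype=float))
    # Scale the jitter to the typical variance so it is unit-aware.
    diag_mean = float(np.mean(np.diag(cov)))
    base = diag_mean if diag_mean > 0 else 1.0
    eye = np.eye(cov.shape[0])
    for attempt in range(max_tries):
        # The first attempt uses the covariance unchanged.
        jitter = 0.0 if attempt == 0 else rel_jitter * base * 10.0 ** (attempt - 1)
        try:
            return multivariate_normal(mean, cov + jitter * eye), jitter
        except np.linalg.LinAlgError:
            continue
    raise np.linalg.LinAlgError(
        "covariance is not positive definite even after adding jitter"
    )


rng = np.random.default_rng(0)
particles = np.column_stack([
    rng.gamma(2.0, 1.0, size=200),
    rng.gamma(2.0, 1.0, size=200),
    np.full(200, 900.0),  # collapsed dimension
])
dist, jitter = robust_mvn(particles.mean(axis=0), np.cov(particles, rowvar=False))
print(jitter > 0)  # True: the raw covariance was rejected
```

SciPy's own `allow_singular=True` flag is the other obvious option: it keeps the exact singular covariance, while jitter keeps the proposal strictly non-degenerate.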

ricardoV94 commented 1 year ago

CC @aloctavodia

ricardoV94 commented 1 year ago

Not sure what the solution should be in this case, when the mvnormal kernel is unstable...

In the meantime, you could try a different kernel that doesn't rely on multivariate sampling of the particles and see if that works for you.
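
To illustrate why a proposal that does not depend on the joint covariance sidesteps this failure mode, here is a hypothetical per-dimension random-walk step. It is not PyMC's `MH` kernel, whose internals are not shown in this thread; `random_walk_step` and `scale` are illustrative names. Each particle is perturbed independently along each axis, scaled by that axis's spread across particles, so a collapsed dimension just receives zero-width noise instead of breaking a matrix factorization:

```python
import numpy as np


def random_walk_step(particles, rng, scale=0.5):
    """Hypothetical random-walk proposal: independent Gaussian noise per
    dimension, scaled by each dimension's standard deviation across particles.
    No covariance matrix is factorized, so zero variance cannot raise."""
    std = particles.std(axis=0)
    noise = rng.standard_normal(particles.shape)
    return particles + scale * std * noise


rng = np.random.default_rng(0)
particles = np.column_stack([
    rng.gamma(2.0, 1.0, size=200),
    np.full(200, 900.0),  # collapsed dimension
])
proposed = random_walk_step(particles, rng)
print(proposed.shape)             # (200, 2)
print(np.unique(proposed[:, 1]))  # [900.]
```

The trade-off is that the collapsed dimension stays frozen, so this avoids the crash without resolving the degeneracy itself.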

omsai commented 1 year ago

That worked! After switching from the Independent Metropolis-Hastings kernel to the Metropolis-Hastings kernel, I no longer get that crash in 20 runs:

$ for i in {1..20}; do python3 -Wignore perelson1993pymc_scipy_01_model.py; echo $i: $?; done
Initializing SMC sampler...
Sampling 8 chains in 8 jobs
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
1: 0
...[repeats 20x]...
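
The rhat warning above refers to the diagnostic described in the linked paper (arXiv:1903.08008). As a rough illustration of what it measures, here is classic split-Rhat in NumPy. ArviZ's default is the rank-normalized variant from that paper, so its values will differ somewhat; `split_rhat` is a hypothetical helper:

```python
import numpy as np


def split_rhat(draws):
    """Hypothetical helper: classic split-Rhat (Gelman et al.), without the
    rank normalization and folding that ArviZ applies by default.

    draws: array of shape (n_chains, n_draws) for a single parameter.
    """
    draws = np.asarray(draws, dtype=float)
    _, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain into a first and second half (the middle draw is
    # dropped when n_draws is odd), doubling the number of chains.
    split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
    n = split.shape[1]
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)      # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(8, 2000))     # 8 well-mixed chains
stuck = mixed + np.arange(8)[:, None]  # chains stuck at different levels
print(split_rhat(mixed) < 1.01, split_rhat(stuck) > 1.01)  # True True
```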

$ diff -u perelson1993pymc_scipy_01_model.py{.orig,}
--- perelson1993pymc_scipy_01_model.py.orig 2023-06-21 17:45:21
+++ perelson1993pymc_scipy_01_model.py  2023-06-21 17:32:25
@@ -9,6 +9,7 @@
 import numpy as np
 import pandas as pd
 import pymc as pm
+from pymc.smc.kernels import MH
 from scipy.integrate import odeint

@@ -98,7 +99,7 @@
                            epsilon=10,
                            observed=OBSERVED)
         # Collect inference data.
-        idata_hiv = pm.sample_smc(cores=8, random_seed=456)
+        idata_hiv = pm.sample_smc(cores=8, kernel=MH)

     with open("perelson1993-01_output-idata_hiv.pkl", "wb") as file_:
         pickle.dump(idata_hiv, file_)

NB: I edited the original program above to add random_seed=456 so that it crashes every time.