pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

BUG: `pymc.sample_smc` fails with `pymc.CustomDist` #7224

Open EliasRas opened 3 months ago

EliasRas commented 3 months ago

Describe the issue:

pymc.sample_smc raises a NotImplementedError because of a missing logp method when a pymc.CustomDist is defined in a model without the dist argument. Two workarounds avoid the error: passing dist, or switching to pm.Potential.

Reproducible code example:

import pymc as pm
import numpy as np

def _logp(value, mu, sigma):
    dist = pm.Normal.dist(mu=mu, sigma=sigma)

    return pm.logp(dist, value)

def _random(mu, sigma, rng, size):
    if rng is None:
        rng = np.random.default_rng()
    sample = rng.normal(loc=mu, scale=sigma, size=size)

    return sample

def _logcdf(value, mu, sigma):
    dist = pm.Normal.dist(mu=mu, sigma=sigma)

    return pm.logcdf(dist, value)

def _dist(mu, sigma, size):
    return pm.Normal.dist(mu, sigma, size=size)

def main():
    data = np.random.default_rng().normal(5, 2, 1000)

    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
            observed=data,
        )
        sample = pm.sample_smc()  # NotImplementedError

    with pm.Model():
        pm.CustomDist(
            "y",
            2,
            10,
            logp=_logp,
            random=_random,
            logcdf=_logcdf,
        )
        sample = pm.sample_smc()  # NotImplementedError

    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.CustomDist(
            "y",
            mu,
            sigma,
            dist=_dist,
            observed=data,
        )
        sample = pm.sample_smc()  # Works

    with pm.Model():
        mu = pm.Normal("mu", mu=0, sigma=10)
        sigma = pm.HalfNormal("sigma", sigma=10)
        pm.Potential(
            "y",
            _logp(data, mu, sigma),
        )
        sample = pm.sample_smc()  # Works

if __name__ == "__main__":
    main()

Error message:

```shell
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 421, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 344, in _sample_smc_int
    smc._initialize_kernel()
  File "\envs\pymc\Lib\site-packages\pymc\smc\kernels.py", line 239, in _initialize_kernel
    initial_point, [self.model.varlogp], self.variables, shared
                    ^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 832, in varlogp
    return self.logp(vars=self.free_RVs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\model\core.py", line 717, in logp
    rv_logps = transformed_conditional_logp(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 612, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
                      ^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\basic.py", line 542, in conditional_logp
    q_logprob_vars = _logprob(
                     ^^^^^^^^^
  File "\envs\pymc\Lib\functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\logprob\abstract.py", line 63, in _logprob
    raise NotImplementedError(f"Logprob method not implemented for {op}")
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "issue.py", line 79, in <module>
    main()
  File "issue.py", line 44, in main
    sample = pm.sample_smc()  # Exception has occurred: NotImplementedError
             ^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 213, in sample_smc
    results = run_chains_parallel(
              ^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 390, in run_chains_parallel
    results = _starmap_with_kwargs(
              ^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\site-packages\pymc\smc\sampling.py", line 417, in _starmap_with_kwargs
    return pool.starmap(_apply_args_and_kwargs, args_for_starmap)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 375, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "\envs\pymc\Lib\multiprocessing\pool.py", line 774, in get
    raise self._value
NotImplementedError: Logprob method not implemented for CustomDist_y_rv{0, (0, 0), floatX, False}
```

PyMC version information:

```
Python 3.11.7
pymc 5.10.0
pytensor 2.18.6
Win 10
Environment set up via conda but updated pymc and pytensor with pip
```

Also fails with these environments: [conda_env.txt](https://github.com/pymc-devs/pymc/files/15465575/conda_env.txt) [conda_env_dev.txt](https://github.com/pymc-devs/pymc/files/15465576/conda_env_dev.txt)

Context for the issue:

I'm testing a model which suffers from slow sampling, possibly due to expensive gradient calculations. I tested SMC as a possible solution as suggested on the forums but got this error message.

Using the dist argument could work in most cases, but there are cases where the distributions provided by pymc are not enough. Using pm.Potential could help with sampling, but that would in turn make forward sampling less straightforward.

welcome[bot] commented 3 months ago

:tada: Welcome to PyMC! :tada: We're really excited to have your input into the project! :sparkling_heart:
If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

EliasRas commented 3 months ago

I tested pm.sample_smc(cores=1) and got no error, which made me dig a bit deeper. If I understood correctly, the error with multiple processes happens because e.g. logp gets registered only in the main process. Would it be possible to add an initializer to the pool used in pm.smc.sampling.run_chains_parallel that ensures the methods are registered properly in each worker?
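To make the idea concrete, here is a minimal stdlib-only sketch of the pattern, not PyMC code. `CustomOp`, `logprob`, `register_custom_logprob`, `evaluate` and `run_parallel` are all hypothetical stand-ins: `logprob` plays the role of a `functools.singledispatch` registry, and the registration happens at runtime inside a function, so a freshly spawned worker would not have it unless the pool initializer repeats it.

```python
import multiprocessing as mp
from functools import singledispatch


class CustomOp:
    """Hypothetical stand-in for an Op type created at runtime."""


@singledispatch
def logprob(op, value):
    # Fallback: nothing has been registered for this type in this process.
    raise NotImplementedError(f"Logprob method not implemented for {op!r}")


def _custom_logprob(op, value):
    # Toy log-density (unnormalized standard normal), for illustration only.
    return -0.5 * value**2


def register_custom_logprob():
    """Register the implementation in the current process; safe to repeat."""
    logprob.register(CustomOp, _custom_logprob)


def evaluate(value):
    # Runs inside a worker; succeeds only if registration happened there too.
    return logprob(CustomOp(), value)


def run_parallel(values, processes=2):
    # "spawn" starts each worker from a fresh interpreter (the Windows default),
    # so registrations made at runtime in the parent are not inherited.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes, initializer=register_custom_logprob) as pool:
        return pool.map(evaluate, values)
```

In a script, `run_parallel([0.0, 1.0, 2.0])` would be called under an `if __name__ == "__main__":` guard. Under these assumptions, dropping the `initializer` argument should make the workers hit the `NotImplementedError` fallback, which is the same shape as the failure in the traceback above.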

ricardoV94 commented 1 month ago

As mentioned in the linked PR, I cannot reproduce the problem locally or on Google Colab, so it may not be a bug but an issue with how it was installed / VSCode: https://colab.research.google.com/drive/1I1n6c9IlmXknIfhxC5s7sAQghv0vfRSY?usp=sharing

Can you try to test with latest PyMC and directly from the terminal?

ricardoV94 commented 1 month ago

Could also be a Windows-only bug, so perhaps someone with a Windows machine can try to reproduce. We can run the new test on the Windows job if it's not running right now
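For anyone trying to reproduce this, a small report of the environment may help compare setups. The sketch below uses only the standard library; `environment_report` is a hypothetical helper name, and the assumption that the multiprocessing start method is relevant here is unconfirmed (Windows only supports "spawn", while "fork" has been the usual default on Linux).

```python
import multiprocessing as mp
import platform
import sys
from importlib import metadata


def environment_report(packages=("pymc", "pytensor")):
    """Collect details that may differ between failing and working setups."""
    report = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        # Start method used when a pool is created without an explicit context.
        "start_method": mp.get_start_method(),
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running this from the terminal in the same environment as the failing script, and posting the output, would show whether the reports differ between the Windows setup and Colab.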