pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

BUG: as_op not pickled, making parallel SMC crash #7078

Open jucor opened 7 months ago

jucor commented 7 months ago

Describe the issue:

As it stands, the SMC sampler cannot be parallelized with custom ops.

When using the SMC sampler with more than one core (i.e. parallel sampling) and an as_op custom op, the op is not pickled properly by the "manual" pickling at https://github.com/pymc-devs/pymc/blob/118be0f23782945dc03c5fb36d58d6ce4a1f619f/pymc/smc/sampling.py#L385, which causes the run to fail.

Reproducible code example:

import pymc as pm
import pytensor.tensor as pt

from pytensor.compile.ops import as_op

@as_op(itypes=[pt.dvector], otypes=[pt.dvector])
def twice(x):
    return 2*x

with pm.Model() as model:
    x = pm.Normal('x', mu=[0, 0], sigma=1)
    y = twice(x)
    z = pm.Normal(name='z', mu=y, observed=[1, 1])

    # Using cores=1 would work, but cores=2 throws an error
    trace = pm.sample_smc(10,cores=2)

Error message:

<details>
{
    "name": "AttributeError",
    "message": "module '__main__' has no attribute 'twice'",
    "stack": "---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
\"\"\"
Traceback (most recent call last):
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 320, in _sample_smc_int
    (draws, kernel, start, model) = map(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/ops.py\", line 221, in load_back
    obj = getattr(module, name)
          ^^^^^^^^^^^^^^^^^^^^^
AttributeError: module '__main__' has no attribute 'twice'
\"\"\"

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
Cell In[14], line 2
      1 with model:
----> 2     trace = pm.sample_smc(10,cores=2)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:213, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    210 t1 = time.time()
    212 if cores > 1:
--> 213     results = run_chains_parallel(
    214         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
    215     )
    216 else:
    217     results = run_chains_sequential(
    218         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
    219     )

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:388, in run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores)
    386 params = tuple(cloudpickle.dumps(p) for p in params)
    387 kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
--> 388 results = _starmap_with_kwargs(
    389     pool,
    390     to_run,
    391     [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
    392     repeat(kernel_kwargs),
    393 )
    394 results = tuple(cloudpickle.loads(r) for r in results)
    395 pool.close()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:415, in _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter)
    411 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    412     # Helper function to allow kwargs with Pool.starmap
    413     # Copied from https://stackoverflow.com/a/53173433/13311693
    414     args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
--> 415     return pool.starmap(_apply_args_and_kwargs, args_for_starmap)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:375, in Pool.starmap(self, func, iterable, chunksize)
    369 def starmap(self, func, iterable, chunksize=None):
    370     '''
    371     Like `map()` method but the elements of the `iterable` are expected to
    372     be iterables as well and will be unpacked as arguments. Hence
    373     `func` and (a, b) becomes func(a, b).
    374     '''
--> 375     return self._map_async(func, iterable, starmapstar, chunksize).get()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:774, in ApplyResult.get(self, timeout)
    772     return self._value
    773 else:
--> 774     raise self._value

AttributeError: module '__main__' has no attribute 'twice'"
}
</details>

PyMC version information:

pymc: 5.10.3 pytensor: 2.18.4 python: 3.11.7

Installed in a fresh conda environment with conda create -c conda-forge -n pymc_env "pymc>=5"

Context for the issue:

As it stands, the SMC sampler cannot run the official PyMC example from https://www.pymc.io/projects/examples/en/latest/ode_models/ODE_Lotka_Volterra_multiple_ways.html. Any simple ODE for which sunode is overkill will crash in the same way, because it requires a custom op, which is not pickled.

The workaround of using a single core makes the method much slower than needed.

Is there a way to serialize the custom operation please?

jucor commented 7 months ago

@aloctavodia Given all your work on SMC and its parallelization (in particular https://github.com/pymc-devs/pymc/pull/3981), would you have any idea what's going on, please, and how to add those ops to what's being pickled, please? Thanks a lot for any idea :)

ricardoV94 commented 7 months ago

May want to try and define the Op in a python script instead of at runtime

jucor commented 7 months ago

Thanks for the idea! I'll try this workaround and report here. Not ideal in the long term (a single-file example being very handy for what I'm trying to achieve pedagogically) but I'll definitely take it if it works :-)

jucor commented 7 months ago

@ricardoV94 That workaround works!! 🎉 Awesome, that'll be perfect until a longer-term fix lands :)
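
For readers who want to see why moving the definition into a file helps, here is a minimal stdlib-only sketch. It is not PyMC or PyTensor code: a plain function stands in for the op, `custom_ops` is a hypothetical module name, and removing the module from `sys.modules` is only a rough imitation of a freshly started worker process.

```python
import importlib
import pickle
import sys
import tempfile
from pathlib import Path

# Write a tiny module to disk; `custom_ops` is a hypothetical file name.
module_dir = Path(tempfile.mkdtemp())
(module_dir / "custom_ops.py").write_text("def twice(x):\n    return 2 * x\n")
sys.path.insert(0, str(module_dir))
importlib.invalidate_caches()

# The function is now importable by name, unlike one defined in a notebook cell.
custom_ops = importlib.import_module("custom_ops")

# pickle records the function as a reference (module name + function name).
payload = pickle.dumps(custom_ops.twice)

# Roughly imitate a fresh worker that has not imported the module yet.
del sys.modules["custom_ops"]

# Unpickling re-imports custom_ops from disk and looks up `twice` there.
restored = pickle.loads(payload)
result = restored(21)
print(result)  # 42
```

In the setting of this issue, the file would instead hold the `@as_op`-decorated `twice`, and the notebook would import it from there, so that a worker can find it again by module and name.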

Now the smc sampler is hitting another issue lower down, which could also be related to pickling but seems linked to the progress bar, complaining about HTML not existing. Any idea if an extra import somewhere could help?

{
    "name": "NameError",
    "message": "name 'HTML' is not defined",
    "stack": "---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
\"\"\"
Traceback (most recent call last):
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py\", line 51, in starmapstar
    return list(itertools.starmap(args[0], args[1]))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 419, in _apply_args_and_kwargs
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py\", line 342, in _sample_smc_int
    progressbar.update_bar(getattr(progressbar, \"offset\", 0) + 0)
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/fastprogress/fastprogress.py\", line 81, in update_bar
    self.on_update(val, f'{pct}[{val}/{tot} {elapsed_t}{self.lt}{remaining_t}{end}]')
  File \"/opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/fastprogress/fastprogress.py\", line 133, in on_update
    if self.display: self.out.update(HTML(self.progress))
                                     ^^^^
NameError: name 'HTML' is not defined
\"\"\"

The above exception was the direct cause of the following exception:

NameError                                 Traceback (most recent call last)
Cell In[421], line 4
      2 draws = 2000
      3 with model:
----> 4     trace_SMC_like = pm.sample_smc(draws,cores=5)
      5 trace = trace_SMC_like
      6 az.summary(trace)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:213, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    210 t1 = time.time()
    212 if cores > 1:
--> 213     results = run_chains_parallel(
    214         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores
    215     )
    216 else:
    217     results = run_chains_sequential(
    218         chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs
    219     )

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:388, in run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores)
    386 params = tuple(cloudpickle.dumps(p) for p in params)
    387 kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
--> 388 results = _starmap_with_kwargs(
    389     pool,
    390     to_run,
    391     [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)],
    392     repeat(kernel_kwargs),
    393 )
    394 results = tuple(cloudpickle.loads(r) for r in results)
    395 pool.close()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/site-packages/pymc/smc/sampling.py:415, in _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter)
    411 def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter):
    412     # Helper function to allow kwargs with Pool.starmap
    413     # Copied from https://stackoverflow.com/a/53173433/13311693
    414     args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter)
--> 415     return pool.starmap(_apply_args_and_kwargs, args_for_starmap)

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:375, in Pool.starmap(self, func, iterable, chunksize)
    369 def starmap(self, func, iterable, chunksize=None):
    370     '''
    371     Like `map()` method but the elements of the `iterable` are expected to
    372     be iterables as well and will be unpacked as arguments. Hence
    373     `func` and (a, b) becomes func(a, b).
    374     '''
--> 375     return self._map_async(func, iterable, starmapstar, chunksize).get()

File /opt/homebrew/Caskroom/mambaforge/base/envs/pymc_env/lib/python3.11/multiprocessing/pool.py:774, in ApplyResult.get(self, timeout)
    772     return self._value
    773 else:
--> 774     raise self._value

NameError: name 'HTML' is not defined"
}

jucor commented 7 months ago

Dang, the latter seems to be related to https://github.com/fastai/fastprogress/issues/32 and https://github.com/pymc-devs/pymc/issues/5855 and https://github.com/pymc-devs/pymc/issues/5980 , none of which seems to have had an actual resolution :/

jucor commented 7 months ago

A really ugly workaround is to call pm.sample_smc(..., progressbar=False) , which does not try to render the progressbar in the notebook and thus skips the error. But that means the user is flying completely blind while the sampler is running, which is not ideal.

jucor commented 7 months ago

I confirm that the problem with fastprogress only occurs with cores > 1, so it's definitely tied to the parallelism.

fastprogress also works fine standalone in a notebook.

Inspecting its code in https://github.com/fastai/fastprogress/blob/master/fastprogress/fastprogress.py#L104 confirms that it checks its import of HTML to make sure the widget works. So the way we serialize/unserialize, or parallelize, must screw it up somehow.

jucor commented 7 months ago

The way I understand it, both problems come down to the fact that the "manual" serialization is missing some symbols: the local op in one case, or the HTML object imported by fastprogress when it is itself imported in the other.

I don't know enough about closures and namespaces in Python to pinpoint exactly how to spot these missing symbols, capture them, and reserialize them, but I would bet good money it should be done at this point in the SMC sampler code: https://github.com/pymc-devs/pymc/blob/118be0f23782945dc03c5fb36d58d6ce4a1f619f/pymc/smc/sampling.py#L385

I'd be happy to work with anyone here to enrich this serialization!
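
For the op case, the failure mode can be reproduced with the standard library alone. The sketch below is an illustration under assumptions, not PyTensor's implementation: `notebook_main` is a hypothetical stand-in for the notebook's `__main__`, and deleting the attribute imitates a worker process that never ran the cell defining `twice`. It mirrors the `getattr(module, name)` lookup visible in the `load_back` frame of the first traceback.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the notebook's __main__ module.
notebook_main = types.ModuleType("notebook_main")
sys.modules["notebook_main"] = notebook_main


def twice(x):
    return 2 * x


# Pretend `twice` was defined in that module, as a notebook cell would do.
twice.__module__ = "notebook_main"
notebook_main.twice = twice

# pickle stores a reference (module name + function name), not the code itself.
payload = pickle.dumps(twice)

# Imitate a freshly started worker whose module never ran the defining cell.
del notebook_main.twice

try:
    pickle.loads(payload)
    error = None
except AttributeError as exc:
    # Same kind of failure as "module '__main__' has no attribute 'twice'".
    error = exc

print(type(error).__name__, error)
```

If this reading is right, serializing by value rather than by reference would avoid the lookup, but whether that is feasible for ops created with as_op is not something this sketch settles.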

jucor commented 7 months ago

OK, progress bar error figured out: it's not quite due to pickling, it's due to how fastprogress autodetects whether it is running in a notebook and conditionally imports HTML. Upon instantiation, the fastprogress module is loaded from a function called from a notebook, and thus imports the HTML object. However, when the job runs and tries to update the progress bar, it is in a separate process that is not a notebook, so there the module has not imported HTML. I have tried forcing the import, but the link to the notebook is broken and the pretty HTML bar is not updated. However, I have a fix for that issue: using the non-HTML progress bars in the multicore setup. I'll make a PR for that.

So that'll be one of two problems sorted :)
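
The conditional-import pattern described above can be sketched without fastprogress. Everything here is a hypothetical stand-in: `FAKE_PROGRESS_SOURCE` plays the role of a module that defines `HTML` only when it detects a notebook, and `load_fake_progress` imitates importing that module in the parent (notebook) and in a worker (no notebook).

```python
# Source of a toy module that defines HTML only when it thinks it is in a notebook.
FAKE_PROGRESS_SOURCE = '''
if IN_NOTEBOOK:
    # Stand-in for the notebook-only import of HTML.
    def HTML(text):
        return f"<div>{text}</div>"

def on_update(text):
    # Uses HTML unconditionally, like the failing line in the traceback.
    return HTML(text)
'''


def load_fake_progress(in_notebook):
    """Execute the toy module source with the given environment flag."""
    namespace = {"IN_NOTEBOOK": in_notebook}
    exec(FAKE_PROGRESS_SOURCE, namespace)
    return namespace


# Parent process: notebook detected, HTML is defined, the update works.
parent = load_fake_progress(in_notebook=True)
rendered = parent["on_update"]("50%")

# Worker process: no notebook detected, so HTML was never defined.
worker = load_fake_progress(in_notebook=False)
try:
    worker["on_update"]("50%")
    error = None
except NameError as exc:
    error = exc

print(rendered)
print(type(error).__name__, error)
```

This is why switching the workers to a progress bar that never touches `HTML` sidesteps the error, at the cost of the richer notebook rendering.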

jucor commented 7 months ago

PR opened to fix the progress bar bug :) That doesn't help with the requirement to put the as_op in a separate file, but at least it's one thing cleaner :)