pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Register the overloads added by CustomDist in worker processes #7241

Open EliasRas opened 3 months ago

EliasRas commented 3 months ago

Description

Currently sample_smc can fail with a NotImplementedError if it's used with a model defined using CustomDist. If a CustomDist is used without the dist parameter, the overloads for _logprob, _logcdf and _support_point are registered only in the main process.

This PR adds an initializer which registers the overloads in the worker processes of the pool used in sample_smc.
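A minimal sketch of the pool-initializer pattern involved, with a hypothetical helper and a hypothetical custom_methods mapping; this is illustrative, not the actual PR code:

import multiprocessing

def _register_custom_methods(custom_methods):
    # Hypothetical helper: redo the dispatch registrations (_logprob,
    # _logcdf, _support_point) for each CustomDist rv_op type inside the
    # worker, mirroring what CustomDist does in the main process.
    for rv_type, methods in custom_methods.items():
        ...  # e.g. register methods["logp"] for rv_type, and so on

def _init_worker(custom_methods):
    _register_custom_methods(custom_methods)

if __name__ == "__main__":
    custom_methods = {}  # placeholder: would be collected from the model
    with multiprocessing.Pool(
        processes=2, initializer=_init_worker, initargs=(custom_methods,)
    ) as pool:
        pool.map(print, range(2))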

Related Issue

Checklist

Type of change


📚 Documentation preview 📚: https://pymc--7241.org.readthedocs.build/en/7241/

welcome[bot] commented 3 months ago

:sparkling_heart: Thanks for opening this pull request! :sparkling_heart: The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

ricardoV94 commented 3 months ago

Hi @EliasRas, I'll need some time to review this one properly. Thanks for taking the initiative

EliasRas commented 1 month ago

Looks like I messed up by rebasing instead of merging and introduced plenty of unnecessary commits to this feature. Does it need to be fixed?

twiecki commented 1 month ago

Yes, that needs to be fixed; it happens to everyone. One approach is to start clean and cherry-pick your commits.

ricardoV94 commented 1 month ago

I think it's more complicated than this. The following example has a specific dispatch, but no RV that shows up in the graph:

import pymc as pm

def _logp(value, mu):
    return -((value - mu) ** 2)

def _dist(mu, size=None):
    return pm.Normal.dist(mu, 1, size=size)

with pm.Model():
    mu = pm.Normal("mu", 0)
    # The CustomDist never appears in the graph itself; only its logp
    # enters, through the Potential term.
    pm.Potential("term", pm.logp(pm.CustomDist.dist(mu, logp=_logp, dist=_dist), [1, 2]))
    pm.sample_smc(draws=6, cores=1)

It also fails even with a single core

EliasRas commented 1 month ago

It also fails even with a single core

Commit 22e8f0bb4a02d874856438065efa7b3ef2645e13 refactored sample_smc, and I think the errors should now pop up even with a single core, since the sampling is always done in another process. Previously that was the case only when cores > 1, because there were separate run_chains_parallel and run_chains_sequential functions.

ricardoV94 commented 1 month ago

Somehow, in main, I am getting ConnectionResetError: [Errno 104] Connection reset by peer even for unrelated models without any sort of CustomDist

ricardoV94 commented 1 month ago

Somehow, in main, I am getting ConnectionResetError: [Errno 104] Connection reset by peer even for unrelated models without any sort of CustomDist

Okay, it's something about the new progress bar and the PyCharm interactive Python console. If I run it from IPython or a terminal it works. But it also works on main for me?

ricardoV94 commented 1 month ago

I cannot reproduce a failure with your test locally (after avoiding the PyCharm issue) nor in a Colab environment: https://colab.research.google.com/drive/1I1n6c9IlmXknIfhxC5s7sAQghv0vfRSY?usp=sharing

Can you share more details about your environment/setup?

EliasRas commented 1 month ago

Can you share more details about your environment/setup?

I added the output of conda list to the "PyMC version information" section of #7224. I'm running the code using VSCode, if that matters. Do you need anything else?

Basically I followed the install instructions and the pull request tutorial when installing. I might have also pip-installed a couple of extra packages here and there.

ricardoV94 commented 1 month ago

I added the output of conda list to the "PyMC version information" section of https://github.com/pymc-devs/pymc/issues/7224. I'm running the code using VSCode, if that matters. Do you need anything else?

We should have at least one person reproduce the problem, because I cannot. It may be a VSCode environment issue. Ideally we wouldn't have to change the codebase.

EliasRas commented 1 month ago

The test does fail without the changes when I run it from miniforge prompt though.

ricardoV94 commented 1 month ago

The test does fail without the changes when I run it from miniforge prompt though.

Not sure what the miniforge prompt is. Can we try to reproduce it here on the CI then? Push just the test, without the fixes, into a new PR and we'll run it to see if we can reproduce.

EliasRas commented 4 weeks ago

Is there anything that needs to be done here besides running the tests?

twiecki commented 4 weeks ago

Is there anything that needs to be done here besides running the tests?

Sorry for the delay, just kicked off tests.

codecov[bot] commented 4 weeks ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.45%. Comparing base (19be124) to head (997d730). Report is 42 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #7241      +/-   ##
==========================================
- Coverage   92.47%   92.45%   -0.02%
==========================================
  Files         102      102
  Lines       17187    17200      +13
==========================================
+ Hits        15893    15903      +10
- Misses       1294     1297       +3
```

| Files | Coverage Δ | |
|---|---|---|
| [pymc/smc/sampling.py](https://app.codecov.io/gh/pymc-devs/pymc/pull/7241?src=pr&el=tree&filepath=pymc%2Fsmc%2Fsampling.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs#diff-cHltYy9zbWMvc2FtcGxpbmcucHk=) | `99.31% <100.00%> (+0.10%)` | :arrow_up: |

... and [20 files with indirect coverage changes](https://app.codecov.io/gh/pymc-devs/pymc/pull/7241/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs)

lucianopaz commented 4 weeks ago

Thanks @EliasRas, I haven't been able to reproduce this yet, but that's just because I'm in the middle of switching workstations and haven't gotten everything set up yet. Your fix looks fine to me, and I understand what you identified as the cause of the issue: the dispatching mechanism isn't registering the logp and other methods for the dynamically created class.

I think this highlights a caveat in pymc's and pytensor's design: spawned processes may not have all of the dispatch registrations that the main process has. I imagine this is mostly a problem on Windows, where multiprocessing can only spawn new processes, whereas Linux-based systems default to forks, which in principle copy over the memory contents of the main process. I'm not sure what happens under macOS, because I think it cannot use fork multiprocessing for some reason either.

With this design caveat in hand, I'm not sure if it's better to have a package-level utility function that serves as a sort of book-keeper, something that can handle communicating the extra dispatch registrations needed to ensure that child processes use the correct dispatching functions. I'm curious to know what @ricardoV94 thinks about this. I don't think this PR should have to tackle that kind of work, but I think we can discuss whether it's necessary here, and maybe later open an issue and a separate PR (also maybe in pytensor, where dispatching is used for transpilation/compilation and maybe at some point for lazy gradients?).
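As a quick, pymc-independent illustration of the start-method differences: fork copies the parent's memory, so dispatch registries populated at CustomDist creation time come along for free, while spawn starts a fresh interpreter in which those registrations never ran. Linux defaults to fork, Windows to spawn, and macOS to spawn since Python 3.8.

import multiprocessing as mp

# Prints "fork" on Linux and "spawn" on Windows and macOS by default.
print(mp.get_start_method())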

aseyboldt commented 4 weeks ago

I guess the underlying reason for the failure is that pickling of DensityDist doesn't work out of the box? Sounds like for some reason the dispatch functions don't get registered when the object is unpickled. But wouldn't it be cleaner to overwrite the pickling behavior of this class then? We could override __getstate__ and __setstate__ methods to that effect?
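A minimal sketch of that pattern on a toy class rather than the actual DensityDist op; the registry and all names here are stand-ins:

import pickle

_LOGP_REGISTRY = {}  # stand-in for pymc's _logprob dispatch registry

def toy_logp(value):
    return -(value**2)

class ToyDist:
    def __init__(self, logp):
        self.logp = logp
        _LOGP_REGISTRY[type(self)] = logp  # registration at creation time

    def __setstate__(self, state):
        # Re-run the registration whenever the instance is unpickled,
        # e.g. inside a spawned worker process.
        self.__dict__.update(state)
        _LOGP_REGISTRY[type(self)] = self.logp

d = pickle.loads(pickle.dumps(ToyDist(toy_logp)))
assert type(d) in _LOGP_REGISTRY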

lucianopaz commented 4 weeks ago

I guess the underlying reason for the failure is that pickling of DensityDist doesn't work out of the box?

I don’t think the problem is about pickling. The DensityDist ends up returning an op that can be cloudpickled. If I recall correctly, it can’t be pickled with plain pickle because the op class is created on the fly. In the process of creating the op, the dispatchers get populated with the callables that are supplied as inputs to the distribution class. As far as I understand, those functions are detached from the rv op, and that’s why they never get populated in a spawned process.
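A toy illustration of that last point, independent of pymc: plain pickle cannot serialize instances of a class created on the fly, while cloudpickle serializes the class by value:

import pickle
import cloudpickle

def make_op():
    class DynamicOp:  # created on the fly, like the CustomDist rv op class
        pass
    return DynamicOp()

op = make_op()
try:
    pickle.dumps(op)
except (pickle.PicklingError, AttributeError) as e:
    print("pickle failed:", e)  # DynamicOp can't be looked up by module path

restored = cloudpickle.loads(cloudpickle.dumps(op))  # cloudpickle handles it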

aseyboldt commented 4 weeks ago

I don't mean that the pickling itself throws an error (it doesn't), but that it would be the responsibility of the DensityDist object to ensure that the set-up it needs (i.e. registering the logp) is done when it is unpickled.

For instance, the following fails with the NotImplementedError and has nothing to do with smc, so I guess the solution shouldn't be specific to smc?

import pymc as pm
import cloudpickle
import multiprocessing

def use_logp_func(pickled_model):
    # Unpickle and compile in a fresh (spawned) process, where the
    # dispatch registrations from the main process never happened.
    model = cloudpickle.loads(pickled_model)
    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    print(func(1.0))

if __name__ == "__main__":
    with pm.Model() as model:

        def logp(value):
            return -(value**2)

        pm.DensityDist("x", logp=logp)

    # Compiling works fine in the main process...
    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    pickled_model = cloudpickle.dumps(model)

    # ...but fails with NotImplementedError in a spawned child process.
    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(target=use_logp_func, args=(pickled_model,))
    process.start()
    process.join()

lucianopaz commented 4 weeks ago

I completely agree that this problem isn’t unique to smc and is a design caveat that needs to be addressed more comprehensively. I think we can patch a few things:

  1. Make Model objects' __getstate__ and __setstate__ repopulate the dispatch registries.
  2. Get CustomDist rv ops to have these methods defined somehow (maybe via closures) so that they repopulate the dispatch registries.

I’m not sure if these two methods can cover all use patterns though.

ricardoV94 commented 4 weeks ago

Alternatively, we could pass the functions needed to each process, which is more like what pm.sample does.

This also avoids recompiling the same functions multiple times?

EliasRas commented 2 weeks ago

@lucianopaz Point 1 is pretty straightforward, but could you explain what you meant by point 2? How would it be different from overriding __getstate__ and __setstate__?