Open EliasRas opened 3 months ago
:sparkling_heart: Thanks for opening this pull request! :sparkling_heart: The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.
Hi @EliasRas, I'll need some time to review this one properly. Thanks for taking the initiative
Looks like I messed up by rebasing instead of merging and introduced plenty of unnecessary commits to this feature. Does it need to be fixed?
Yes, that needs to be fixed, happens to everyone. One approach is to start clean and cherry-pick your commits.
I think it's more complicated than this. The following example has specific dispatch, but no RV that shows up in the graph:
```python
import pymc as pm

def _logp(value, mu):
    return -((value - mu) ** 2)

def _dist(mu, size=None):
    return pm.Normal.dist(mu, 1, size=size)

with pm.Model():
    mu = pm.Normal("mu", 0)
    pm.Potential("term", pm.logp(pm.CustomDist.dist(mu, logp=_logp, dist=_dist), [1, 2]))
    pm.sample_smc(draws=6, cores=1)
```
It also fails even with a single core
> It also fails even with a single core

22e8f0bb4a02d874856438065efa7b3ef2645e13 refactored `sample_smc`, and I think that errors should now pop up even with a single core, since the sampling is always done in another process. Previously this was the case only when `cores > 1`, because there were separate `run_chains_parallel` and `run_chains_sequential` functions.
Somehow, in main, I am getting `ConnectionResetError: [Errno 104] Connection reset by peer` even for unrelated models without any sort of `CustomDist`.
> Somehow, in main, I am getting `ConnectionResetError: [Errno 104] Connection reset by peer` even for unrelated models without any sort of `CustomDist`.
Okay, it's something about the new progress bar and the PyCharm interactive Python console. If I run it from IPython or a terminal it works. But it also works in main for me?
I cannot reproduce a failure with your test locally (after avoiding the PyCharm issue) nor in a Colab environment: https://colab.research.google.com/drive/1I1n6c9IlmXknIfhxC5s7sAQghv0vfRSY?usp=sharing
Can you share more details about your environment/setup?
> Can you share more details about your environment/setup?
I added the output of `conda list` to the "PyMC version information" section of #7224. I'm running the code using VSCode, if that matters. Do you need anything else?

Basically I followed the install instructions and the pull request tutorial when installing. Might have also `pip install`ed a couple of extra packages here and there.
We should have at least one person reproduce the problem, because I cannot. It may be a VSCode environment issue. Ideally we wouldn't have to change the codebase.
The test does fail without the changes when I run it from miniforge prompt though.
> The test does fail without the changes when I run it from miniforge prompt though.
Not sure what the miniforge prompt is. Can we try to reproduce here on the CI then? Push just the test without the fixes into a new PR and we'll run it to see if we can reproduce.
Is there anything that needs to be done here besides running the tests?
> Is there anything that needs to be done here besides running the tests?
Sorry for the delay, just kicked off tests.
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 92.45%. Comparing base (19be124) to head (997d730). Report is 42 commits behind head on main.
Thanks @EliasRas, I haven't been able to reproduce this yet, but that's just because I'm in the middle of switching workstations and haven't gotten everything set up yet.
Your fix looks fine to me, and I understand what you identified as the cause of the issue: the dispatching mechanism isn't registering the logp and other methods for the dynamically created class. I think that this highlights a caveat in `pymc`'s and `pytensor`'s design: spawned processes may not have all the registered dispatch signatures of the main process. I imagine that this is mostly a problem on Windows, where multiprocessing can only spawn new processes, whereas Linux-based systems default to forks, which in principle should copy over the memory contents of the main process. I'm not sure what will happen under macOS, because I think that it cannot use fork multiprocessing for some reason either.
With this design caveat in hand, I'm not sure if it's better to have a package-level utility function that serves as a sort of book-keeper, or something that can handle communicating the extra dispatch registrations needed to ensure that child processes will use the correct dispatching functions. I'm curious to know what @ricardoV94 thinks about this. I don't think that this PR should have to tackle this kind of work, but I think that we can discuss whether it's necessary here, and maybe later open an issue and a separate PR (also maybe in `pytensor`, where dispatching is used for transpilation/compilation and maybe at some point for lazy gradients?).
I guess the underlying reason for the failure is that pickling of `DensityDist` doesn't work out of the box? Sounds like for some reason the dispatch functions don't get registered when the object is unpickled. But wouldn't it be cleaner to overwrite the pickling behavior of this class then? We could override the `__getstate__` and `__setstate__` methods to that effect?
> I guess the underlying reason for the failure is that pickling of `DensityDist` doesn't work out of the box?
I don’t think the problem is about pickling. The `DensityDist` ends up returning an op that can be cloudpickled. If I recall correctly, it can’t be pickled because the op class is created on the fly. In the process of creating the op, the dispatchers get populated with the callables that are supplied as inputs to the distribution class. As far as I understand, those functions are detached from the rv op, and that’s why they never get populated in a spawned process.
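The "created on the fly" part is why stdlib pickle chokes while cloudpickle doesn't: pickle stores classes by reference (module plus qualified name), and a dynamically created class can't be looked up that way, whereas cloudpickle serializes such classes by value. A minimal sketch (hypothetical class name):

```python
import pickle

# A class created at runtime, like the RV op classes CustomDist builds:
cls = type("OnTheFlyOp", (), {})

try:
    # pickle tries to resolve __main__.OnTheFlyOp by attribute lookup, which fails
    pickle.dumps(cls())
    pickled_ok = True
except Exception:
    pickled_ok = False
print(pickled_ok)  # False
```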
I don't mean that the pickling itself throws an error (it doesn't), but that it would be the responsibility of the `DensityDist` object to ensure that the set-up it needs (i.e. registering the logp) is done when it is unpickled.
For instance, the following fails with the `NotImplementedError`, and has nothing to do with smc, so I guess the solution shouldn't be specific to smc?
```python
import pymc as pm
import cloudpickle
import multiprocessing

def use_logp_func(pickled_model):
    model = cloudpickle.loads(pickled_model)
    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    print(func(1.0))

if __name__ == "__main__":
    with pm.Model() as model:
        def logp(value):
            return -(value**2)

        pm.DensityDist("x", logp=logp)

    logp = model.logp()
    func = pm.pytensorf.compile_pymc(model.value_vars, logp)
    pickled_model = cloudpickle.dumps(model)
    ctx = multiprocessing.get_context("spawn")
    process = ctx.Process(target=use_logp_func, args=(pickled_model,))
    process.start()
    process.join()
```
I completely agree that this problem isn’t unique to smc and is a design caveat that needs to be addressed more comprehensively. I think that we can kind of patch some things:

1. `Model` objects' `__setstate__` and `__getstate__` repopulate the dispatch registries.
2. `CustomDist` rv ops have these methods defined somehow (maybe closures) that repopulate the dispatch registries.

I’m not sure if these two methods can cover all use patterns though.
Alternatively, we could pass the functions needed to each process, which is more like what `pm.sample` does. This also avoids recompiling the same functions multiple times?
@lucianopaz Point 1. is pretty straightforward, but could you explain what you meant by 2.? How would it be different from overriding `__getstate__` and `__setstate__`?
Description

Currently `sample_smc` can fail due to a `NotImplementedError` if it's used with a model defined using `CustomDist`. If a `CustomDist` is used without the `dist` parameter, the overloads for `_logprob`, `_logcdf` and `_support_point` are registered only in the main process.

This PR adds an initializer which registers the overloads in the worker processes of the pool used in `sample_smc`.

Related Issue

Checklist

Type of change

📚 Documentation preview 📚: https://pymc--7241.org.readthedocs.build/en/7241/