pymc-devs / pymc4

(Deprecated) Experimental PyMC interface for TensorFlow Probability. Official work on this project has been discontinued.
Apache License 2.0

Added pm.Deterministic class objects to sample state #311

Closed: semohr closed this pull request 4 years ago

semohr commented 4 years ago

Hi, I tried to subclass the pm.Distribution classes to add some extra functionality, i.e. to add some attributes that can be defined for plotting or other purposes later.

For example:

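```python
import pymc4 as pm

class my_Normal(pm.Normal):
    def __init__(self, *args, **kwargs):
        # Keep the extra "foo" attribute (e.g. for plotting) on the instance
        if "foo" in kwargs:
            self.foo = kwargs.get("foo")
        super().__init__(*args, **kwargs)

@pm.model
def model_name():  # illustrative model that uses my_Normal
    x = yield my_Normal("my_Normal_name", loc=0.0, scale=1.0, foo="bar")
    return x

# Later, evaluate the model to get the sampling state
_, state = pm.evaluate_model(model_name())
# Now we can access the foo attribute via
state.distributions["model_name/my_Normal_name"].foo  # This works!
```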
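The same subclassing pattern can be checked without PyMC4 installed. This is a minimal sketch, assuming a hypothetical stand-in base class `StrictBase` (not part of PyMC4) whose constructor rejects unknown keyword arguments; here `foo` is popped before calling `super().__init__`, which is one way to keep the extra attribute from being forwarded to a parent that does not expect it.

```python
class StrictBase:
    """Hypothetical stand-in for a distribution base class (not PyMC4's)."""

    def __init__(self, name, *, loc=0.0, scale=1.0):
        # Only these keyword arguments are accepted; anything else raises TypeError
        self.name = name
        self.loc = loc
        self.scale = scale


class MyNormal(StrictBase):
    def __init__(self, *args, **kwargs):
        # Remove "foo" so the parent constructor never sees it
        self.foo = kwargs.pop("foo", None)
        super().__init__(*args, **kwargs)


# Hypothetical registry keyed like "model_name/<variable name>"
registry = {}
dist = MyNormal("my_Normal_name", loc=0.0, scale=1.0, foo="plot-label")
registry["model_name/" + dist.name] = dist
print(registry["model_name/my_Normal_name"].foo)  # plot-label
```

Unlike the snippet above, this variant always defines `foo` (as `None` when it is not passed), so later code can read it without checking `hasattr` first.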

This is all nice and works, but if we try to apply the same concept to the pm.Deterministic class, we find that state.deterministics["model_name/my_Deterministic_name"] is a tensor (the computed value, not the object), and the object is not added to state.distributions either. That is intended behaviour, so I added an extra dict to the sampling state for deterministic "distributions", which allows me to do:

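```python
class my_Deterministic(pm.Deterministic):
    def __init__(self, *args, **kwargs):
        if "foo" in kwargs:
            self.foo = kwargs.get("foo")
        super().__init__(*args, **kwargs)

@pm.model
def model_name():  # illustrative model that uses my_Deterministic
    x = yield pm.Normal("x", loc=0.0, scale=1.0)
    y = yield my_Deterministic("my_Deterministic_name", 2.0 * x, foo="bar")
    return y

_, state = pm.evaluate_model(model_name())
state.deterministic_distributions["model_name/my_Deterministic_name"].foo  # This works now too!
```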
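To illustrate the idea behind the extra dict, here is a minimal, PyMC4-free sketch. `Deterministic`, `MyDeterministic`, `SamplingState`, and `add_deterministic` are hypothetical names, not PyMC4's actual executor code: the state keeps the computed value in `deterministics` (as before) and, separately, the object that produced it in `deterministic_distributions`, so custom attributes such as `foo` stay reachable under the same key.

```python
from typing import Any, Dict


class Deterministic:
    """Hypothetical stand-in for a deterministic node (not PyMC4's class)."""

    def __init__(self, name: str, value: Any):
        self.name = name
        self.value = value


class MyDeterministic(Deterministic):
    def __init__(self, *args, **kwargs):
        # Keep the extra attribute and do not forward it to the parent
        self.foo = kwargs.pop("foo", None)
        super().__init__(*args, **kwargs)


class SamplingState:
    """Hypothetical sampling state with a separate dict for deterministic objects."""

    def __init__(self):
        self.deterministics: Dict[str, Any] = {}  # key -> computed value only
        self.deterministic_distributions: Dict[str, Deterministic] = {}  # key -> node object

    def add_deterministic(self, scope: str, node: Deterministic) -> Any:
        key = f"{scope}/{node.name}"
        self.deterministics[key] = node.value
        self.deterministic_distributions[key] = node
        return node.value


state = SamplingState()
state.add_deterministic(
    "model_name", MyDeterministic("my_Deterministic_name", 2.0 * 3.0, foo="bar")
)
print(state.deterministics["model_name/my_Deterministic_name"])  # 6.0
print(state.deterministic_distributions["model_name/my_Deterministic_name"].foo)  # bar
```

Storing the object in a second dict leaves the existing `deterministics` mapping (values only) unchanged, so code that reads values keeps working.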

This all feels a bit hacky, and I'm not sure whether it breaks anything, but I found it quite useful. If you think this could be a feature, feel free to merge the pull request or change anything you see fit. Otherwise, just reject/close the request ;)

Best wishes and keep up the good work, Sebastian

codecov[bot] commented 4 years ago

Codecov Report

Merging #311 into master will increase coverage by 0.02%. The diff coverage is 100.00%.


```diff
@@            Coverage Diff             @@
##           master     #311      +/-   ##
==========================================
+ Coverage   90.02%   90.05%   +0.02%     
==========================================
  Files          33       33              
  Lines        2447     2453       +6     
==========================================
+ Hits         2203     2209       +6     
  Misses        244      244              
```
| Impacted Files | Coverage Δ |
|---|---|
| pymc4/coroutine_model.py | 89.83% <ø> (ø) |
| pymc4/flow/executor.py | 94.40% <100.00%> (+0.12%) :arrow_up: |
| pymc4/flow/meta_executor.py | 87.50% <100.00%> (ø) |
| pymc4/forward_sampling.py | 97.87% <100.00%> (ø) |
| pymc4/inference/sampling.py | 93.40% <100.00%> (ø) |
| pymc4/inference/utils.py | 95.23% <100.00%> (ø) |
| pymc4/variational/approximations.py | 91.20% <100.00%> (ø) |

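The headline percentages in the report follow from the Hits and Lines counts in the coverage diff. The sketch below recomputes them, assuming Codecov's default of rounding percentages down to two decimals (this is configurable per repository, so treat it as an approximation of how the report derives its numbers).

```python
import math


def round_down(pct: float, digits: int = 2) -> float:
    """Truncate a percentage to `digits` decimals (assumed Codecov default)."""
    factor = 10 ** digits
    return math.floor(pct * factor) / factor


def coverage(hits: int, lines: int) -> float:
    """Line coverage in percent."""
    return 100.0 * hits / lines


base_hits, base_lines = 2203, 2447  # master
head_hits, head_lines = 2209, 2453  # #311

base = coverage(base_hits, base_lines)  # about 90.0286
head = coverage(head_hits, head_lines)  # about 90.0530

print(round_down(base))  # 90.02
print(round_down(head))  # 90.05
print(round_down(head - base))  # 0.02
print(base_lines - base_hits, head_lines - head_hits)  # 244 244
```

Subtracting the already-rounded figures would give 0.03, so the reported +0.02% appears to be computed from the unrounded values.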
junpenglao commented 4 years ago

This looks good to me. @fonnesbeck could we go ahead and merge?

fonnesbeck commented 4 years ago

Thanks for this!