AI-DI / Brancher

A user-centered Python package for differentiable probabilistic inference
https://brancher.org/
MIT License

Bounded variables? #20

Open gkericks opened 5 years ago

gkericks commented 5 years ago

Hi, I am looking to implement a probabilistic model in Brancher that requires a truncated normal distribution. In PyMC3 that can be achieved with https://docs.pymc.io/api/bounds.html as well as an explicit TruncatedNormal class (see the sketch below). Is there an equivalent or some sort of work-around in Brancher? Thanks!
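For concreteness, the PyMC3 version looks roughly like this (a sketch against the PyMC3 3.x API; keyword names such as sigma vs. sd vary between releases):

```python
import pymc3 as pm

with pm.Model():
    # Explicit truncated normal, bounded on both sides
    x = pm.TruncatedNormal("x", mu=0.0, sigma=1.0, lower=-1.0, upper=1.0)

    # The same idea via the Bound factory, here with a lower bound only
    BoundedNormal = pm.Bound(pm.Normal, lower=0.0)
    y = BoundedNormal("y", mu=0.0, sigma=1.0)
```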

LucaAmbrogioni commented 5 years ago

Hi,

The half-normal distribution is already implemented in the standard_variables module. It would be nice to implement a more general truncated normal; I'll probably do that soon.

Brancher allows you to truncate an arbitrary probabilistic model using an arbitrary boolean rule. You need to import the truncate_model function from the transformations module (see the example below).

Unfortunately, this function is not yet documented and its public interface is not very user-friendly at the moment.

It takes as input:

1) model: the probabilistic model you want to truncate.
2) model_statistics: a function that takes a sample as input and returns some values.
3) truncation_rule: a function that takes the output of model_statistics as input and returns a boolean indicating whether the sample should be kept or not.

Here is how to use it to create a truncated normal:

```python
from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable
from brancher.transformations import truncate_model

a = NormalVariable(0, 1, "a")
model = ProbabilisticModel([a])

model_statistics = lambda sample: sample[a]  # extracts the value of the variable a
truncation_rule = lambda x: x < 2            # a sample is rejected if the value of a is >= 2

truncated_model = truncate_model(model,
                                 truncation_rule=truncation_rule,
                                 model_statistics=model_statistics)
```
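Assuming the truncated model exposes the same sampling interface as a regular ProbabilisticModel (an assumption, since the interface is undocumented), drawing from it would look like:

```python
# Rejection sampling under the hood: only draws with a < 2 are kept.
samples = truncated_model.get_sample(1000)  # assuming the usual get_sample interface
```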

gkericks commented 5 years ago

Thanks! I'm still verifying that the workaround works for my application, but it looks promising.

gkericks commented 5 years ago

@LucaAmbrogioni So I have been able to implement my solution using this truncation, and while it does work (for inference), it possibly introduces some inefficiencies into getting the means and variances of my posterior distribution (maybe I am simply not doing this the best way).

What I am trying to do is compute the mean and stddev of two of my variables after inference. Computing time is a resource here, so making that computation fast would be helpful. I tried the model.get_mean() method, but the model appears not to have an analytical solution that you support. So I have been sampling from the posterior and computing the mean and std from those samples, and here is where the truncation causes an issue: if the sample count is small, the model._get_posterior_sample(n) call will spin forever and not return.

I am trying to limit the size of the sample because of computing time, but a small sample also seems to break sometimes. I think this might be because the truncation rule can possibly reject all samples, and something deep in the Brancher code then spins its wheels, but I'm not 100% sure.
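For concreteness, the empirical computation looks roughly like this (a sketch; I'm assuming the posterior sample is keyed by variable, as the sample[a] indexing above suggests, and that the values coerce to a NumPy array):

```python
import numpy as np

n = 200  # keeping n small for speed, but small batches are where the call hangs
posterior_sample = model._get_posterior_sample(n)

# 'a' stands in for one of the two variables of interest
values = np.asarray(posterior_sample[a]).reshape(n, -1)
post_mean = values.mean(axis=0)
post_std = values.std(axis=0, ddof=1)
```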

Is there a better way to compute those means (they are normal variables)? Or is the official truncated normal distribution you mentioned on the roadmap now?

Thanks!

gkericks commented 5 years ago

For more insight, here is an example of the execution stack when I interrupt:

```
~/.local/lib/python3.6/site-packages/brancher/transformations.py in truncated_get_sample(number_samples, max_itr, **kwargs)
     32         itr = 0
     33         while (current_number_samples < number_samples and itr < max_itr) or current_number_samples < 1:
---> 34             remaining_samples, n, p = reject_samples(model._get_sample(batch_size, **kwargs),
     35                                                      model_statistics=model_statistics,
     36                                                      truncation_rule=truncation_rule)

~/.local/lib/python3.6/site-packages/brancher/variables.py in _get_sample(self, number_samples, observed, input_values, differentiable)
   1152                                         for var in self._input_variables])
   1153         joint_sample.update(input_values)
-> 1154         self.reset()
   1155         return joint_sample
   1156

~/.local/lib/python3.6/site-packages/brancher/variables.py in reset(self)
   1474         """
   1475         for var in self.flatten():
-> 1476             var.reset(recursive=False)
   1477
   1478     def _flatten(self):

~/.local/lib/python3.6/site-packages/brancher/variables.py in reset(self, recursive)
    554             return {self: value}
    555
--> 556     def reset(self, recursive=False):
    557         """
    558         Method. It recursively resets the self._evaluated and self._current_value attributes of the variable
```
I believe there is an infinite loop happening between these calls: if the truncation rule rejects every sample, current_number_samples never reaches 1, and the or current_number_samples < 1 clause in the while condition keeps the loop running even after itr reaches max_itr.
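A minimal, self-contained sketch of that loop structure (hypothetical rejection_loop and accept_prob names, just to illustrate the failure mode):

```python
import random

def rejection_loop(number_samples, max_itr, batch_size, accept_prob):
    """Mimics the while-condition in brancher.transformations.truncated_get_sample."""
    current_number_samples = 0
    itr = 0
    while (current_number_samples < number_samples and itr < max_itr) \
            or current_number_samples < 1:
        # Stand-in for reject_samples(): keep each draw with probability accept_prob
        current_number_samples += sum(random.random() < accept_prob
                                      for _ in range(batch_size))
        itr += 1
    return current_number_samples, itr

# With accept_prob = 0.0 the call below never returns: itr passes max_itr,
# but current_number_samples stays 0, so "or current_number_samples < 1"
# keeps the loop alive forever.
# rejection_loop(number_samples=10, max_itr=5, batch_size=4, accept_prob=0.0)
```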