gkericks opened 5 years ago
Hi,
The half-normal distribution is already implemented in the `standard_variables` module. It would be nice to implement a more general truncated normal; I'll probably do that soon.
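As a rough illustration of what a general truncated normal involves, one rejection-free approach is inverse-CDF sampling: draw a uniform value inside the CDF interval [Phi(alpha), Phi(beta)] of the standardized bounds and invert it. The sketch below is only an assumption about how such a distribution could work. It uses Python's standard `statistics.NormalDist` (Python 3.8+), not any Brancher API, and the function name `sample_truncated_normal` is hypothetical. It becomes numerically inaccurate for truncations far in the tail, where Phi(alpha) and Phi(beta) are both very close to 0 or to 1.

```python
import math
import random
from statistics import NormalDist
from typing import List, Optional


def sample_truncated_normal(
    mu: float,
    sigma: float,
    low: float = -math.inf,
    high: float = math.inf,
    n: int = 1,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical helper: n draws from N(mu, sigma^2) truncated to [low, high].

    Inverse-CDF method: map a uniform draw into [Phi(alpha), Phi(beta)] and
    invert the standard normal CDF, so no draw is ever rejected.
    """
    if not low < high:
        raise ValueError("low must be smaller than high")
    rng = rng or random.Random()
    std = NormalDist()  # standard normal N(0, 1)
    # CDF values of the standardized bounds; infinite bounds give 0.0 and 1.0.
    cdf_low = std.cdf((low - mu) / sigma)
    cdf_high = std.cdf((high - mu) / sigma)
    eps = 1e-12  # NormalDist.inv_cdf requires 0 < p < 1
    samples = []
    for _ in range(n):
        p = cdf_low + rng.random() * (cdf_high - cdf_low)
        p = min(max(p, eps), 1.0 - eps)
        x = mu + sigma * std.inv_cdf(p)
        samples.append(min(max(x, low), high))  # guard against rounding at the bounds
    return samples


# Same truncation as a ~ N(0, 1) restricted to a < 2.
draws = sample_truncated_normal(0.0, 1.0, high=2.0, n=1000, rng=random.Random(0))
```

Because every uniform draw maps to an accepted value, the cost is fixed per sample and cannot stall, however strict the bounds are.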
Brancher allows you to truncate an arbitrary probabilistic model using an arbitrary Boolean rule. To do this, import the `truncate_model` function from the `transformations` module.
Unfortunately, this function is not documented yet, and its public interface is not very user-friendly at the moment.
It takes three inputs:
1) model: the probabilistic model you want to truncate.
2) model_statistics: a function that takes a sample as input and returns some values computed from it.
3) truncation_rule: a function that takes the output of model_statistics as input and returns a Boolean indicating whether the sample should be kept.
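Taken together, these three inputs describe rejection sampling: draw from the model, reduce each draw with `model_statistics`, and keep it only if `truncation_rule` accepts the result. The following library-agnostic sketch in plain Python illustrates that idea; `truncate_sampler`, the `draw` callable and the dict-shaped samples are hypothetical stand-ins and do not reflect Brancher's internals.

```python
import random
from typing import Any, Callable, Dict, List

Sample = Dict[str, float]


def truncate_sampler(
    draw: Callable[[], Sample],
    model_statistics: Callable[[Sample], Any],
    truncation_rule: Callable[[Any], bool],
) -> Callable[[int], List[Sample]]:
    """Hypothetical helper: wrap a sampler so it only returns accepted draws.

    Note: there is no attempt cap, so this loops forever if the rule
    never accepts a draw.
    """

    def get_sample(number_samples: int) -> List[Sample]:
        kept: List[Sample] = []
        while len(kept) < number_samples:
            sample = draw()
            # Reduce the draw to its statistics, then let the rule decide.
            if truncation_rule(model_statistics(sample)):
                kept.append(sample)
        return kept

    return get_sample


rng = random.Random(0)
get_sample = truncate_sampler(
    draw=lambda: {"a": rng.gauss(0.0, 1.0)},
    model_statistics=lambda sample: sample["a"],
    truncation_rule=lambda x: x < 2,
)
samples = get_sample(100)
```

Separating `model_statistics` from `truncation_rule` lets the same rule be reused on different summaries of a joint sample.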
Here is how to use `truncate_model` to create a truncated Normal:

```python
from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable
from brancher.transformations import truncate_model

a = NormalVariable(0, 1, "a")
model = ProbabilisticModel([a])
model_statistics = lambda sample: sample[a]  # This extracts the value of the variable a
truncation_rule = lambda x: x < 2  # A sample is rejected if the value of a is greater than or equal to 2
truncated_model = truncate_model(model, truncation_rule=truncation_rule, model_statistics=model_statistics)
```
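Rejection is cheap in this particular example because almost every draw already satisfies the rule. For a ~ N(0, 1), P(a < 2) = Phi(2) ≈ 0.977, so on average about 1/0.977 ≈ 1.02 draws are needed per kept sample. The short check below uses Python's `statistics.NormalDist` (not part of Brancher) to show how that cost grows as the rule gets stricter. The helper name `expected_draws_per_sample` is hypothetical.

```python
from statistics import NormalDist


def expected_draws_per_sample(upper: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Hypothetical helper: mean number of draws from N(mu, sigma^2) needed
    to obtain one draw below `upper` with plain rejection sampling."""
    acceptance = NormalDist(mu, sigma).cdf(upper)  # P(a < upper)
    if acceptance == 0.0:
        return float("inf")  # the rule rejects (numerically) every draw
    return 1.0 / acceptance


for upper in (2.0, 0.0, -2.0, -5.0):
    print(upper, round(expected_draws_per_sample(upper), 2))
```

The expected cost is the reciprocal of the acceptance probability. It stays near 1 for mild truncations and explodes when the allowed region lies in the tail.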
Thanks! I'm still verifying that the workaround works for my application but it looks promising.
@LucaAmbrogioni I have been able to implement my solution using this truncation, and while it does work for inference, it may introduce some inefficiencies when computing the means and variances of my posterior distribution (maybe I am simply not doing this the best way).
After inference, I need to compute the mean and standard deviation of two of my variables. Computing time is a limited resource here, so making that computation fast would help. I tried the model.get_mean() method, but my model does not appear to have an analytical solution that Brancher supports. So I have been sampling from the posterior and computing the mean and standard deviation from those samples, and this is where the truncation causes an issue: if the number of samples is small, the model._get_posterior_sample(n) call spins forever and never returns. I am trying to limit the sample size because of computing time, but small samples also seem to break sometimes. I think the truncation rule might reject all samples, leaving something deep in the Brancher code spinning its wheels, but I'm not 100% sure.
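When the mean and standard deviation are estimated from samples, the standard error of the mean shrinks as s / sqrt(n). This means the sample size can be chosen for a target precision from a small pilot run rather than by trial and error. The sketch below uses Python's standard `statistics` module. The helper names, the pilot distribution and the tolerance are illustrative assumptions, and the formula assumes independent draws, which holds for a direct sampler but not for correlated MCMC chains.

```python
import math
import random
import statistics
from typing import Sequence, Tuple


def summarize(samples: Sequence[float]) -> Tuple[float, float, float]:
    """Return (mean, sample standard deviation, standard error of the mean)."""
    n = len(samples)
    mean = statistics.fmean(samples)
    std = statistics.stdev(samples)  # requires at least two samples
    return mean, std, std / math.sqrt(n)


def samples_needed(std: float, tolerance: float) -> int:
    """Independent draws needed so the standard error of the mean is <= tolerance."""
    return math.ceil((std / tolerance) ** 2)


rng = random.Random(0)
pilot = [rng.gauss(1.0, 0.5) for _ in range(200)]  # small pilot run
mean, std, sem = summarize(pilot)
print(mean, std, sem, samples_needed(std, tolerance=0.01))
```

Halving the tolerance quadruples the required number of samples, which is worth knowing before cutting the sample size for speed.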
Is there a better way to compute those means (they are normal variables)? Or is the official truncated normal distribution you mentioned on the roadmap now?
Thanks!
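If the quantity of interest really is a normal variable truncated to an interval (for example a ~ N(0, 1) restricted to a < 2, as in the `truncate_model` example in this thread), its mean and variance have closed forms and no sampling is needed. With alpha = (low - mu)/sigma, beta = (high - mu)/sigma and Z = Phi(beta) - Phi(alpha), the formulas are E[X] = mu + sigma (phi(alpha) - phi(beta))/Z and Var[X] = sigma^2 [1 + (alpha phi(alpha) - beta phi(beta))/Z - ((phi(alpha) - phi(beta))/Z)^2]. This only applies when the marginal is itself a truncated normal, which a posterior obtained by conditioning on data generally is not. The sketch uses only the standard library, is unrelated to Brancher's `get_mean()`, and the helper name is hypothetical.

```python
import math
from statistics import NormalDist
from typing import Tuple

_STD = NormalDist()  # standard normal N(0, 1)


def truncated_normal_moments(
    mu: float, sigma: float, low: float = -math.inf, high: float = math.inf
) -> Tuple[float, float]:
    """Hypothetical helper: (mean, variance) of N(mu, sigma^2) truncated to [low, high]."""
    alpha = (low - mu) / sigma
    beta = (high - mu) / sigma
    z = _STD.cdf(beta) - _STD.cdf(alpha)
    if z <= 0.0:
        raise ValueError("truncation interval has (numerically) zero probability")
    pdf_a = _STD.pdf(alpha)  # 0.0 when alpha is infinite
    pdf_b = _STD.pdf(beta)
    # x * pdf(x) -> 0 as x -> +/-inf; avoid inf * 0.0 = nan.
    a_pdf_a = alpha * pdf_a if math.isfinite(alpha) else 0.0
    b_pdf_b = beta * pdf_b if math.isfinite(beta) else 0.0
    ratio = (pdf_a - pdf_b) / z
    mean = mu + sigma * ratio
    var = sigma ** 2 * (1.0 + (a_pdf_a - b_pdf_b) / z - ratio ** 2)
    return mean, var


# a ~ N(0, 1) restricted to a < 2
mean, var = truncated_normal_moments(0.0, 1.0, high=2.0)
print(mean, math.sqrt(var))
```

For the a < 2 case, this gives a mean of about -0.055 and a standard deviation of about 0.94, at the cost of a few function evaluations.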
For more insight, here is an example of the execution stack when I interrupt the call:
```
~/.local/lib/python3.6/site-packages/brancher/transformations.py in truncated_get_sample(number_samples, max_itr, **kwargs)
32 itr = 0
33 while (current_number_samples < number_samples and itr < max_itr) or current_number_samples < 1:
---> 34 remaining_samples, n, p = reject_samples(model._get_sample(batch_size, **kwargs),
35 model_statistics=model_statistics,
36 truncation_rule=truncation_rule)
~/.local/lib/python3.6/site-packages/brancher/variables.py in _get_sample(self, number_samples, observed, input_values, differentiable)
1152 for var in self._input_variables])
1153 joint_sample.update(input_values)
-> 1154 self.reset()
1155 return joint_sample
1156
~/.local/lib/python3.6/site-packages/brancher/variables.py in reset(self)
1474 """
1475 for var in self.flatten():
-> 1476 var.reset(recursive=False)
1477
1478 def _flatten(self):
~/.local/lib/python3.6/site-packages/brancher/variables.py in reset(self, recursive)
554 return {self: value}
555
--> 556 def reset(self, recursive=False):
557 """
558 Method. It recursively resets the self._evaluated and self._current_value attributes of the variable
```
I believe there is an infinite loop happening between the recursive calls.
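The loop condition in the first frame of the traceback, `(current_number_samples < number_samples and itr < max_itr) or current_number_samples < 1`, suggests one possible mechanism. The `or current_number_samples < 1` clause keeps the loop running past `max_itr` until at least one sample is accepted, so a rule that rejects every draw would never terminate. This is an inference from the displayed code only, not a confirmed diagnosis. The sketch below shows a rejection loop that gives up after a fixed number of batches and raises an error instead of spinning. The names `rejection_sample`, `max_attempts` and `TruncationError` are hypothetical and not part of Brancher.

```python
import random
from typing import Callable, List


class TruncationError(RuntimeError):
    """Raised when too few draws pass the truncation rule."""


def rejection_sample(
    draw_batch: Callable[[int], List[float]],
    truncation_rule: Callable[[float], bool],
    number_samples: int,
    batch_size: int = 100,
    max_attempts: int = 1000,
) -> List[float]:
    """Hypothetical helper: rejection sampling with a hard cap on batches."""
    kept: List[float] = []
    for _ in range(max_attempts):
        kept.extend(x for x in draw_batch(batch_size) if truncation_rule(x))
        if len(kept) >= number_samples:
            return kept[:number_samples]
    # Fail loudly instead of spinning when the rule (almost) never accepts.
    raise TruncationError(
        f"only {len(kept)} of {number_samples} samples accepted "
        f"after {max_attempts} batches of {batch_size}"
    )


rng = random.Random(0)


def draw(n: int) -> List[float]:
    """Draw n standard normal values."""
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


print(len(rejection_sample(draw, lambda x: x < 2, number_samples=50)))
try:
    rejection_sample(draw, lambda x: x > 100, number_samples=1, max_attempts=10)
except TruncationError as err:
    print("gave up:", err)
```

A hard cap turns a silent hang into an error that the caller can handle, for example by relaxing the rule or falling back to a rejection-free sampler.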
Hi, I am looking to implement a probabilistic model in Brancher that requires a truncated normal distribution. In pymc3, this can be achieved with https://docs.pymc.io/api/bounds.html as well as with an explicit TruncatedNormal class. Is there an equivalent, or some sort of workaround, in Brancher? Thanks