jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

probability() and log_probability() won't accept "Masked" tensors #1025

Open jdschmitt11 opened 1 year ago

jdschmitt11 commented 1 year ago

The probability() and log_probability() methods won't accept masked tensors, so a model can't be evaluated with missing evidence. When I pass the same masked tensor to predict_proba(), it works fine.

IndexError                                Traceback (most recent call last)
Cell In[26], line 1
----> 1 model.probability(evidence_masked)

File Miniconda3\envs\bayesian-api\lib\site-packages\pomegranate\distributions\_distribution.py:61, in Distribution.probability(self, X)
     60 def probability(self, X):
---> 61     return torch.exp(self.log_probability(X))

File Miniconda3\envs\bayesian-api\lib\site-packages\pomegranate\bayesian_network.py:352, in BayesianNetwork.log_probability(self, X)
    349     if len(parents) > 1:
    350         X_ = X_.unsqueeze(-1)
--> 352     logps += distribution.log_probability(X_)
    354 return logps

File Miniconda3\envs\bayesian-api\lib\site-packages\pomegranate\distributions\categorical.py:175, in Categorical.log_probability(self, X)
    173 logps = torch.zeros(X.shape[0], dtype=self.probs.dtype)
    174 for i in range(self.d):
--> 175     logps += self._log_probs[i][X[:, i]]
    177 return logps

File Miniconda3\envs\bayesian-api\lib\site-packages\torch\masked\maskedtensor\core.py:274, in MaskedTensor.__torch_function__(cls, func, types, args, kwargs)
    272     return NotImplemented
    273 with torch._C.DisableTorchFunctionSubclass():
--> 274     ret = func(*args, **kwargs)
...
    500 # We save the function ptr as the op attribute on
    501 # OpOverloadPacket to access it here.
--> 502 return self._op(*args, **kwargs or {})

IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [28] at index 0

jmschrei commented 1 year ago

Thanks for the issue. What are you hoping the return to be? Would you like it to infer the missing states and then calculate log probabilities using those, or only calculate log probabilities for those states that have complete observations? I think the second would lead to inflated values when examples have missing values.
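The two options can be compared on a tiny hand-built network. This is a minimal sketch in plain Python, not pomegranate code: the network A -> B, its probabilities, and the function names are hypothetical, None marks a missing value, and option 2 is read here as "add a variable's log-factor only when it and its parents are observed".

```python
import math
from itertools import product

# Hypothetical toy network A -> B with binary states (numbers invented for illustration).
P_A = {0: 0.7, 1: 0.3}
P_B_GIVEN_A = {0: {0: 0.8, 1: 0.2}, 1: {0: 0.1, 1: 0.9}}


def joint(a, b):
    """P(A=a, B=b) = P(a) * P(b | a)."""
    return P_A[a] * P_B_GIVEN_A[a][b]


def log_prob_marginalized(a, b):
    """Option 1: sum the joint over every completion of the missing (None) values."""
    a_values = [a] if a is not None else [0, 1]
    b_values = [b] if b is not None else [0, 1]
    return math.log(sum(joint(x, y) for x, y in product(a_values, b_values)))


def log_prob_observed_only(a, b):
    """Option 2 (as interpreted here): add log-factors only for fully observed families."""
    logp = 0.0
    if a is not None:
        logp += math.log(P_A[a])
    if a is not None and b is not None:
        logp += math.log(P_B_GIVEN_A[a][b])
    return logp


for record in [(1, 1), (1, None), (None, 1)]:
    print(record, log_prob_marginalized(*record), log_prob_observed_only(*record))
```

With B missing, both functions return log P(A=1), because P(B | A=1) sums to 1 over B. With the parent A missing, option 1 returns log P(B=1) = log 0.41, while option 2 drops both factors and returns 0.0, i.e. a probability of 1.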

jdschmitt11 commented 1 year ago

I would want the second option, which I believe is what version 0.14.x returns. One of my use cases is computing a conflict score for anomaly checks, Conflict = log(P(e1) P(e2) ... P(ei) / P(e)), so for a given set of available evidence I need P(e). For the marginal probabilities in the numerator, I am using predict_proba() with all evidence masked and reading off the corresponding P(e1), P(e2), etc. I believe, however, that providing probability()/log_probability() with evidence for one node only with the rest being masked would also provide the marginal probability as well, so P(e3) = P(e1='masked', e2='masked', e3).
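The conflict score can be sketched on a hypothetical two-node network A -> B, with e1 the evidence on A and e2 the evidence on B. The marginals here come from brute-force enumeration, standing in for predict_proba() with all evidence masked, and all numbers are invented. In the usual reading of this measure, a positive score suggests the pieces of evidence are less likely together than they would be if they were independent.

```python
import math
from itertools import product

# Hypothetical toy network A -> B (binary); numbers invented for illustration.
P_A = {0: 0.7, 1: 0.3}
P_B_GIVEN_A = {0: {0: 0.8, 1: 0.2}, 1: {0: 0.1, 1: 0.9}}


def joint(a, b):
    """P(e) for e = (A=a, B=b)."""
    return P_A[a] * P_B_GIVEN_A[a][b]


def marginal_a(a):
    # P(e1): sum the joint over B (stand-in for a predict_proba() marginal)
    return sum(joint(a, b) for b in (0, 1))


def marginal_b(b):
    # P(e2): sum the joint over A
    return sum(joint(a, b) for a in (0, 1))


def conflict(a, b):
    """Conflict = log(P(e1) * P(e2) / P(e))."""
    return math.log(marginal_a(a) * marginal_b(b) / joint(a, b))


for a, b in product((0, 1), repeat=2):
    print(f"A={a}, B={b}: conflict = {conflict(a, b):+.3f}")
```

For A=1, B=0 the score is log(0.3 * 0.59 / 0.03) = log 5.9 > 0, flagging a possible conflict; for A=1, B=1 it is log(0.3 * 0.41 / 0.27) < 0.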

(Sent by email in reply to Jacob Schreiber's comment of Wed, Apr 19, 2023 at 1:27 PM: https://github.com/jmschrei/pomegranate/issues/1025#issuecomment-1515106501)

jdschmitt11 commented 1 year ago

Sorry, please disregard the last sentence of my previous comment: "I believe, however, that providing probability()/log_probability() with evidence for one node only with the rest being masked would also provide the marginal probability as well, so P(e3) = P(e1='masked', e2='masked', e3)". I wasn't able to reproduce this in version 0.14.

jdschmitt11 commented 1 year ago

I had some time to look into this and was able to match the v0.14 results by inserting the following code between lines 345 and 354 of your bayesian_network.py:

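# Inside the per-variable loop of BayesianNetwork.log_probability, where i and
# distribution are the loop variables and X is a MaskedTensor (True = observed).
if X._masked_mask[0, i]:  # note: only the first row's mask is checked
    parents = self._parents[i] + (i,)
    X_ = X._masked_data[:, parents]

    if len(parents) > 1:
        X_ = X_.unsqueeze(-1)

    logps += distribution.log_probability(X_)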

The problem with this is that it forces log_probability() to accept only masked tensors, just like predict_proba() and predict().
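The patch's rule can be sketched in plain Python without pomegranate or PyTorch. This is a hypothetical illustration, not pomegranate's API: masked_log_probability, PARENTS, and LOG_CPT are invented names, and nested lists stand in for _masked_data and _masked_mask. Unlike the patch, it checks the mask of each row rather than only row 0, and it treats a missing mask as fully observed, which is one possible way to keep plain inputs working.

```python
import math

# Hypothetical network A -> B. PARENTS plays the role of self._parents in the patch;
# the probabilities are invented for illustration.
PARENTS = {0: (), 1: (0,)}
LOG_CPT = {  # keyed by (parent values..., own value)
    0: {(0,): math.log(0.7), (1,): math.log(0.3)},
    1: {(0, 0): math.log(0.8), (0, 1): math.log(0.2),
        (1, 0): math.log(0.1), (1, 1): math.log(0.9)},
}


def masked_log_probability(data, mask=None):
    """Per-row version of the patch's rule: add variable i's log-factor only when
    i is observed in that row (mask True = observed, as in MaskedTensor).

    With mask=None every value counts as observed, so plain input is still accepted.
    """
    if mask is None:
        mask = [[True] * len(row) for row in data]

    logps = []
    for row, row_mask in zip(data, mask):
        logp = 0.0
        for i, parents in PARENTS.items():
            if row_mask[i]:  # the patch checks only row 0: X._masked_mask[0, i]
                # Like X._masked_data[:, parents] in the patch, parent values are read
                # from the underlying data even when those parents are masked.
                family = parents + (i,)
                logp += LOG_CPT[i][tuple(row[j] for j in family)]
        logps.append(logp)
    return logps


data = [[1, 1], [1, 0], [0, 1]]
mask = [[True, True], [True, False], [False, True]]
print(masked_log_probability(data))        # all values observed
print(masked_log_probability(data, mask))  # missing values skipped per row
```

The third row shows a consequence of reading parents from the underlying data: B is scored as log P(B=1 | A=0) even though A is masked.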