jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

Probability Method for BayesianNetwork Raises Error in Documentation Example #697

Status: closed (a3huang closed this issue 4 years ago)

a3huang commented 4 years ago

I am using version 0.12.0 of pomegranate.

Following the example in the Bayesian Networks section of the documentation, I built the graphical model for the Monty Hall example exactly as written. However, when I run:

model.probability(['A', 'B', 'C'])

I get the following error:

TypeError: list indices must be integers or slices, not tuple
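That message is Python's generic error for indexing a list with a tuple, as in `x[0, 1]`. A plausible reading, which is an assumption and not something the thread confirms, is that `probability` indexes its input as a 2D array of samples. A minimal sketch that reproduces the same message with plain Python and shows that a 2D NumPy array accepts the same index:

```python
import numpy as np

sample = ['A', 'B', 'C']

# Indexing a plain Python list with a (row, column) tuple raises the same TypeError.
try:
    sample[0, 1]
except TypeError as err:
    message = str(err)
print(message)  # list indices must be integers or slices, not tuple

# Wrapping the sample in a 2D NumPy array (one row per sample) makes tuple indexing valid.
X = np.array([sample])
value = X[0, 1]
print(X.shape, value)  # (1, 3) B
```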

To Reproduce

from pomegranate import BayesianNetwork, ConditionalProbabilityTable, DiscreteDistribution, Node

# Uniform priors over the three doors: the guest's first pick and the prize location.
guest = DiscreteDistribution({'A': 1./3, 'B': 1./3, 'C': 1./3})
prize = DiscreteDistribution({'A': 1./3, 'B': 1./3, 'C': 1./3})
# P(monty | guest, prize); each row is [guest, prize, monty, probability].
monty = ConditionalProbabilityTable(
        [['A', 'A', 'A', 0.0],
         ['A', 'A', 'B', 0.5],
         ['A', 'A', 'C', 0.5],
         ['A', 'B', 'A', 0.0],
         ['A', 'B', 'B', 0.0],
         ['A', 'B', 'C', 1.0],
         ['A', 'C', 'A', 0.0],
         ['A', 'C', 'B', 1.0],
         ['A', 'C', 'C', 0.0],
         ['B', 'A', 'A', 0.0],
         ['B', 'A', 'B', 0.0],
         ['B', 'A', 'C', 1.0],
         ['B', 'B', 'A', 0.5],
         ['B', 'B', 'B', 0.0],
         ['B', 'B', 'C', 0.5],
         ['B', 'C', 'A', 1.0],
         ['B', 'C', 'B', 0.0],
         ['B', 'C', 'C', 0.0],
         ['C', 'A', 'A', 0.0],
         ['C', 'A', 'B', 1.0],
         ['C', 'A', 'C', 0.0],
         ['C', 'B', 'A', 1.0],
         ['C', 'B', 'B', 0.0],
         ['C', 'B', 'C', 0.0],
         ['C', 'C', 'A', 0.5],
         ['C', 'C', 'B', 0.5],
         ['C', 'C', 'C', 0.0]], [guest, prize])

s1 = Node(guest, name="guest")
s2 = Node(prize, name="prize")
s3 = Node(monty, name="monty")

model = BayesianNetwork("Monty Hall Problem")
model.add_states(s1, s2, s3)
model.add_edge(s1, s3)
model.add_edge(s2, s3)
model.bake()

model.probability(['A', 'B', 'C'])  # raises the TypeError above in 0.12.0

KGtk-git commented 4 years ago

I'm also getting this issue. It would be great if someone could fix this ASAP.

jmschrei commented 4 years ago

Sorry about the issue. You have to pass in 2D numpy arrays now. I'll fix the documentation soon.
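Following that advice, the call would take one row per sample, e.g. `model.probability(numpy.array([['A', 'B', 'C']]))` (not verified against 0.12.0 here). The value it should return can be checked by hand, because the network factorizes as P(guest) · P(prize) · P(monty | guest, prize). A minimal sketch independent of pomegranate, mirroring the repro's tables, with hypothetical helper names `monty_given` and `joint_probability`:

```python
import numpy as np

# Priors for the two root nodes, as in the repro.
guest = {'A': 1 / 3, 'B': 1 / 3, 'C': 1 / 3}
prize = {'A': 1 / 3, 'B': 1 / 3, 'C': 1 / 3}

def monty_given(g, p):
    """Hypothetical helper: P(monty | guest=g, prize=p).

    Monty never opens the guest's door or the prize door; he picks
    uniformly among the doors that remain, which reproduces the CPT rows.
    """
    options = [d for d in 'ABC' if d != g and d != p]
    return {d: (1 / len(options) if d in options else 0.0) for d in 'ABC'}

def joint_probability(X):
    """Hypothetical helper: joint probability of each (guest, prize, monty) row of a 2D array."""
    return np.array([guest[g] * prize[p] * monty_given(g, p)[m] for g, p, m in np.asarray(X)])

X = np.array([['A', 'B', 'C']])  # one sample, shape (1, 3)
print(joint_probability(X))  # [0.11111111]
```

For the sample (A, B, C), Monty is forced to open door C, so the joint probability is 1/3 · 1/3 · 1.0 = 1/9.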

jmschrei commented 4 years ago

The documentation is fixed in the next patch, and pomegranate will raise an error when you pass in an array of the wrong size.
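The thread does not show how that size check is implemented. A minimal sketch of such a check, using a hypothetical `check_samples` helper (not pomegranate's actual code) and assuming NumPy input handling:

```python
import numpy as np

def check_samples(X, n_nodes):
    """Hypothetical check: X must be 2D with one column per network node."""
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"expected a 2D array of samples, got {X.ndim}D")
    if X.shape[1] != n_nodes:
        raise ValueError(f"expected {n_nodes} columns, got {X.shape[1]}")
    return X

check_samples([['A', 'B', 'C']], 3)  # passes: shape (1, 3)

try:
    check_samples(['A', 'B', 'C'], 3)  # 1D input, as in the original report
except ValueError as err:
    print(err)  # expected a 2D array of samples, got 1D
```

Failing fast with a `ValueError` that names the expected shape is clearer than letting a 1D list reach code that indexes it with a tuple.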