MaxHalford / sorobn

🧮 Bayesian networks in Python
https://sorobn.streamlit.app
MIT License

CPTs require parent nodes in alphabetical order #19

Closed: danieljaeck closed this issue 3 years ago

danieljaeck commented 3 years ago

Hi Max,

thanks again for putting together hedgehog, a great library with a great API.

I think there might be a small hiccup (or design choice?) in BayesNet.prepare(). Consider the following example:

import hedgehog as hh
import pandas as pd

edges = pd.DataFrame(
    {
        "Parent Node": ["Input A", "Input B"],
        "Child Node": ["Output", "Output"],
    }
)

bn = hh.BayesNet(*edges.itertuples(index=False, name=None))

bn.P['Input A'] = pd.Series({True: 0.7, False: 0.3})
bn.P['Input B'] = pd.Series({True: 0.4, False: 0.6})

# Manual input of a CPT with columns NOT ordered alphabetically
output_cpt = pd.DataFrame(
    {
        "Input B": [True, True, True, True, False, False, False, False],
        "Input A": [True, True, False, False, True, True, False, False],
        "Output": [
            True,
            False,
            True,
            False,
            True,
            False,
            True,
            False,
        ],
        "Prob": [1, 0, 0, 1, 0.5, 0.5, 0.001, 0.999],
    }
)
bn.P["Output"] = output_cpt.set_index(["Input B", "Input A", "Output"])["Prob"]

If we take a look at bn.P["Output"], we see that the probabilities are integrated correctly.

Input B  Input A  Output
True     True     True      1.000
                  False     0.000
         False    True      0.000
                  False     1.000
False    True     True      0.500
                  False     0.500
         False    True      0.001
                  False     0.999
Name: Prob, dtype: float64

Next, we call prepare():

bn.prepare()

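After calling prepare(), we see that the index is sorted (which is completely fine), but the level names are mixed up during renaming. The parent node names are stored in alphabetical order as attributes during __init__() and then applied by position, so the level holding the Input B values is now labelled Input A, and vice versa.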

Input A  Input B  Output
False    False    False     0.999
                  True      0.001
         True     False     0.500
                  True      0.500
True     False    False     1.000
                  True      0.000
         True     False     0.000
                  True      1.000
Name: P(Output | Input A, Input B), dtype: float64
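The relabelling can be reproduced with plain pandas. This is a minimal sketch, assuming (as described above) that prepare() overwrites the level names by position with the alphabetically sorted parent names; it is an illustration, not sorobn's actual code. The values stay where they are and only the labels move.

```python
import pandas as pd

# CPT from the issue: parent levels in (Input B, Input A) order, child last.
# from_product yields exactly the row order used in the issue's DataFrame.
idx = pd.MultiIndex.from_product(
    [[True, False]] * 3, names=["Input B", "Input A", "Output"]
)
cpt = pd.Series([1, 0, 0, 1, 0.5, 0.5, 0.001, 0.999], index=idx, name="Prob")

# Assumed renaming step: overwrite the names by position with the sorted parents.
parents = ["Input B", "Input A"]
relabelled = cpt.copy()
relabelled.index = relabelled.index.set_names(sorted(parents) + ["Output"])

# The level that holds the Input B values is now called "Input A".
print(relabelled.sort_index())
```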

A quick workaround for this behaviour is to sort the parent nodes in the input CPT alphabetically. Nevertheless, this is not very intuitive, and I think a better design choice would be to require a Series with index names, so that each level can be referenced explicitly.

Unfortunately, I am also unable to use the query() method with several nodes specified in event without running prepare() beforehand. For example, bn.query("Output", event={"Input A": True, "Input B": False}) raises ValueError: cannot join with no overlapping index names unless bn.prepare() has been called first.
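The alphabetical-order workaround amounts to reordering the index levels by name before assigning the CPT. A minimal pandas sketch, assuming prepare() expects the parents in sorted order with the child last; the bn.P assignment mentioned in the comment is the one from the example at the top of this issue.

```python
import pandas as pd

idx = pd.MultiIndex.from_product(
    [[True, False]] * 3, names=["Input B", "Input A", "Output"]
)
cpt = pd.Series([1, 0, 0, 1, 0.5, 0.5, 0.001, 0.999], index=idx, name="Prob")

# Put the parent levels in alphabetical order (child last). reorder_levels
# moves each level's labels together with its values, so no probability
# changes its meaning.
parents = [name for name in cpt.index.names if name != "Output"]
aligned = cpt.reorder_levels(sorted(parents) + ["Output"]).sort_index()

# This is the Series one would then assign, e.g. bn.P["Output"] = aligned
print(aligned)
```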

Happy to get your feedback on this!

Best, Daniel

MaxHalford commented 3 years ago

Hello! Thanks for your interest. I read through the issue and ran the code.

What exactly is the problem? Would you prefer the column names not to be sorted during prepare()?

danieljaeck commented 3 years ago

Hi Max, thanks for getting back to me so quickly!

Well, it's not really a matter of preference: the column names are assigned incorrectly during prepare() (Input B is named Input A and Input A is named Input B). This leads to incorrect conditional probabilities.

Given the CPT above, the conditional distribution of Output given {"Input B": False, "Input A": True} should be the following:

Output
False    0.5
True     0.5
Name: P(Output), dtype: float64

Instead, after running prepare(), which mixes up the column names, I get the following:

>>> bn.query("Output", event={"Input B": False, "Input A": True})
Output
False    1.0
Name: P(Output), dtype: float64

Hope this clarifies the matter.

Cheers!
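The wrong answer can be traced by selecting on level names directly. In this sketch, p_output_given is a hypothetical helper, not sorobn's query(): it reads the same CPT once with the correct level names and once with the two parent names swapped (as after prepare()). With the swapped names, the evidence {"Input B": False, "Input A": True} hits the row that actually belongs to Input B=True, Input A=False.

```python
import pandas as pd

idx = pd.MultiIndex.from_product(
    [[True, False]] * 3, names=["Input B", "Input A", "Output"]
)
cpt = pd.Series([1, 0, 0, 1, 0.5, 0.5, 0.001, 0.999], index=idx, name="Prob")


def p_output_given(series, evidence):
    """Hypothetical helper: P(Output | evidence), selecting levels by *name*."""
    return series.xs(tuple(evidence.values()), level=list(evidence))


evidence = {"Input B": False, "Input A": True}
print(p_output_given(cpt, evidence))  # Output=True: 0.5, Output=False: 0.5

# Same values, parent level names swapped by position:
swapped = cpt.rename_axis(["Input A", "Input B", "Output"])
print(p_output_given(swapped, evidence))  # Output=True: 0.0, Output=False: 1.0
```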

MaxHalford commented 3 years ago

By golly, you're right.

I've added some logic to handle the case where the CPT has index names, so you can now provide it either with or without names. This provides the most flexibility. Forcing users to name the indexes of their CPTs could be worthwhile, but I would say that it's too restrictive.
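One way to accept both forms is sketched below. The helper align_cpt is hypothetical and is not sorobn's implementation: if every index level is named, it reorders the levels by name; otherwise (unnamed or only partly named levels) it assumes they are already in alphabetical parent order, child last, and just attaches the names.

```python
import pandas as pd


def align_cpt(cpt, parents, child):
    """Hypothetical helper: return `cpt` with levels (sorted parents..., child)."""
    expected = sorted(parents) + [child]
    names = list(cpt.index.names)
    if all(name is not None for name in names):
        # Named levels: match by name, so the caller's level order is irrelevant.
        if set(names) != set(expected):
            raise ValueError(f"CPT index names {names} do not match {expected}")
        cpt = cpt.reorder_levels(expected)
    else:
        # Unnamed levels: fall back to the positional, alphabetical convention.
        cpt = cpt.rename_axis(expected)
    return cpt.sort_index()


idx = pd.MultiIndex.from_product(
    [[True, False]] * 3, names=["Input B", "Input A", "Output"]
)
named = pd.Series([1, 0, 0, 1, 0.5, 0.5, 0.001, 0.999], index=idx)
aligned = align_cpt(named, parents=["Input A", "Input B"], child="Output")
print(aligned.xs((True, False), level=["Input A", "Input B"]))  # 0.5 / 0.5
```

Matching by name removes the dependence on level order when names are available, while the unnamed path keeps the original positional behaviour, so existing code that already sorts its parents keeps working.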

Thanks for the detailed code examples! I've added them as a unit test.

danieljaeck commented 3 years ago

Great, this looks perfect - thanks so much!