mckinsey / causalnex

A Python library that helps data scientists infer causation rather than observe correlation.
http://causalnex.readthedocs.io/

Incorrect probabilities learnt for Nodes with one parent only #35

Closed. GabrielAzevedoFerreiraQB closed this issue 3 years ago.

GabrielAzevedoFerreiraQB commented 4 years ago

Description

If a node has exactly one parent (e.g. A -> B), fitting the probabilities always assigns it a flat (uniform) distribution.

I dug in and found that the problem comes from pgmpy. I will raise the same issue there too, but I am not sure how we want to handle it in CausalNex in the meantime.

Steps to Reproduce

import numpy as np
import pandas as pd
from causalnex.structure import StructureModel
from causalnex.network import BayesianNetwork

# Structure with a single edge: B has exactly one parent, A
sm = StructureModel()
sm.add_edge("A", "B")

np.random.seed(11)
vals = [1, 2, 3]
A = np.random.choice(vals, size=3000, p=[0.1, 0.3, 0.6])
# P(B | A): you can put any probabilities here, the fitted CPD is flat regardless
B = [np.random.choice(vals, p=[0.9, 0.05, 0.05]) if a == 1 else
     np.random.choice(vals, p=[0.1, 0.2, 0.7]) if a == 2 else
     np.random.choice(vals, p=[0.85, 0.1, 0.05]) if a == 3 else
     np.random.choice(vals) for a in A]

df = pd.DataFrame([A, B], index=["A", "B"]).T

# Learn the node states, then fit the CPDs by maximum likelihood
bn = BayesianNetwork(sm)
bn = bn.fit_node_states(df)
bn = bn.fit_cpds(df, method="MaximumLikelihoodEstimator")
print(bn.cpds["B"].round(decimals=2))

Expected Result

A     1     2     3
B                  
1  0.92  0.11  0.85
2  0.02  0.21  0.10
3  0.06  0.68  0.05

Actual Result

A     1     2     3
B                  
1  0.33  0.33  0.33
2  0.33  0.33  0.33
3  0.33  0.33  0.33
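For discrete data, the maximum-likelihood CPD is just the conditional frequency table, P(B=b | A=a) = count(A=a, B=b) / count(A=a). As a sketch of an independent sanity check that bypasses CausalNex and pgmpy entirely, pandas' crosstab with normalize="columns" computes that table directly. It assumes NumPy's Generator API (default_rng(11)), so the exact numbers will differ from the np.random.seed(11) run above; the names p_b_given_a and cpd_b are illustrative only.

```python
import numpy as np
import pandas as pd

# Assumption: a fresh Generator, so the sampled values differ from the issue's run
rng = np.random.default_rng(11)
vals = [1, 2, 3]
# P(B | A) used to simulate the data (same table as the reproduction script)
p_b_given_a = {1: [0.9, 0.05, 0.05], 2: [0.1, 0.2, 0.7], 3: [0.85, 0.1, 0.05]}

a = rng.choice(vals, size=3000, p=[0.1, 0.3, 0.6])
b = np.array([rng.choice(vals, p=p_b_given_a[int(x)]) for x in a])
df = pd.DataFrame({"A": a, "B": b})

# ML estimate of P(B | A): counts of B per state of A, normalised column-wise
cpd_b = pd.crosstab(df["B"], df["A"], normalize="columns")
print(cpd_b.round(2))
```

Each column should sum to 1 and roughly recover the simulated probabilities, which is the shape of the "Expected Result" table rather than the flat 0.33 table.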

Cause

The bug comes from pgmpy, specifically the file pgmpy/estimators/base.py, around line 127:

parents_states = [self.state_names[parent] for parent in parents]
state_count_data = data.groupby([variable] + parents).size().unstack(parents)

row_index = self.state_names[variable]
column_index = pd.MultiIndex.from_product(parents_states, names=parents)
state_counts = state_count_data.reindex(index=row_index, columns=column_index).fillna(0) # <----Where the error is

If the node has more than one parent, the columns of state_count_data are a MultiIndex from the start, so state_count_data.reindex(..., columns=column_index) works as intended.

If the node has a single parent, however, unstack() returns plain (non-MultiIndex) columns, while column_index is still a one-level MultiIndex. In that case state_count_data.reindex(..., columns=column_index) matches none of the existing column labels, so the result is a DataFrame full of NaNs; after fillna(0) every count is zero, which is why the fitted CPD comes out flat.
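The difference in column types can be reproduced with a toy frame. The data below is made up, and C is a hypothetical second parent added only for comparison; the sketch shows that unstacking one level leaves plain columns, unstacking two gives a MultiIndex, and a target built with MultiIndex.from_product is a MultiIndex even for a single parent.

```python
import pandas as pd

# Toy data (made up): B is the child, A and C are candidate parents
df = pd.DataFrame({"A": [1, 2, 2, 1], "C": [1, 1, 2, 2], "B": [1, 2, 3, 3]})

# One parent: unstack(["A"]) leaves the columns as a plain Index
one_parent = df.groupby(["B", "A"]).size().unstack(["A"])

# Two parents: unstack(["A", "C"]) produces MultiIndex columns
two_parents = df.groupby(["B", "A", "C"]).size().unstack(["A", "C"])

# The reindex target is a MultiIndex even when there is only one parent
column_index = pd.MultiIndex.from_product([[1, 2]], names=["A"])

print(isinstance(one_parent.columns, pd.MultiIndex))   # False
print(isinstance(two_parents.columns, pd.MultiIndex))  # True
print(isinstance(column_index, pd.MultiIndex), column_index.nlevels)  # True 1
```

This mismatch between a plain Index and a one-level MultiIndex is what the single-parent branch has to reconcile before reindexing.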

Dirty solution

Convert state_count_data.columns to a MultiIndex before reindexing:

parents_states = [self.state_names[parent] for parent in parents]
state_count_data = data.groupby([variable] + parents).size().unstack(parents)

row_index = self.state_names[variable]
if len(parents) == 1:  # ADD THIS IF CONDITION
    # from_product expects a list of iterables, so wrap the column labels in a list
    state_count_data.columns = pd.MultiIndex.from_product([list(state_count_data.columns)], names=parents)
column_index = pd.MultiIndex.from_product(parents_states, names=parents)
state_counts = state_count_data.reindex(index=row_index, columns=column_index).fillna(0)
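To try the idea outside pgmpy, here is a minimal standalone sketch. count_states is a hypothetical helper, not a pgmpy function, and it assumes the node has at least one parent. Rather than testing len(parents) == 1, it checks whether the columns are already a MultiIndex and, if not, wraps them with MultiIndex.from_arrays, which handles the single-parent case while leaving the multi-parent case untouched.

```python
import pandas as pd


def count_states(data, variable, parents, state_names):
    """Count each state of `variable` per joint state of `parents`.

    Hypothetical helper for illustration; assumes len(parents) >= 1.
    """
    parents_states = [state_names[parent] for parent in parents]
    state_count_data = data.groupby([variable] + parents).size().unstack(parents)

    # With a single parent, unstack() returns plain (non-Multi) columns;
    # wrap them in a one-level MultiIndex so they match column_index below.
    if not isinstance(state_count_data.columns, pd.MultiIndex):
        state_count_data.columns = pd.MultiIndex.from_arrays(
            [state_count_data.columns], names=parents
        )

    row_index = state_names[variable]
    column_index = pd.MultiIndex.from_product(parents_states, names=parents)
    # Reindex so unobserved states / parent configurations appear as zero counts
    return state_count_data.reindex(index=row_index, columns=column_index).fillna(0)


df = pd.DataFrame({"A": [1, 1, 2, 3, 3, 3], "B": [1, 2, 3, 1, 1, 2]})
counts = count_states(df, "B", ["A"], {"A": [1, 2, 3], "B": [1, 2, 3]})
print(counts)
# Normalising each column gives the ML estimate of P(B | A)
print(counts / counts.sum())
```

With the columns aligned, the counts per state of A add up to the number of rows with that state, so normalising gives a proper (non-flat) CPD.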
GabrielAzevedoFerreiraQB commented 4 years ago

I raised an issue on pgmpy as well:

https://github.com/pgmpy/pgmpy/issues/1252

benhorsburgh commented 4 years ago

This appears to have been fixed in pgmpy 0.1.9. I'll investigate bumping the version and will update soon. For now, though, you may be able to upgrade pgmpy yourself.

GabrielAzevedoFerreiraQB commented 4 years ago

Just a note: should we update the requirement in causalnex to pgmpy 0.1.9?

oentaryorj commented 3 years ago

I see that requirements.txt now specifies pgmpy>=0.1.12, <0.2.0. Does this resolve the issue?

PNEPNE commented 3 years ago

Yes, the actual result is now the same as the expected result. Please see my notebook below: issue 35 notebook