py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

Error handling column names in CIT as used by CausalModel.graph_refute method #949

Open arainboldt opened 1 year ago

arainboldt commented 1 year ago

This code raises a KeyError (an indexing error) when CausalModel.refute_graph() runs the conditional independence test.

```python
from dowhy.causal_model import CausalModel
import networkx as nx
import numpy as np
import pandas as pd

# 100 rows of Binomial(n=2) draws, so every column holds 0, 1 or 2
graph_df = pd.DataFrame(np.random.binomial(2, [.2, .3, .2, .5], size=(100, 4)),
                        index=range(100),
                        columns=['Foo', 'Bar', 'treatment', 'outcome'])

edge_list = [
    ('Foo', 'Bar'),
    ('Foo', 'outcome'),
    ('Bar', 'outcome'),
    ('treatment', 'outcome'),
]

expert_G = nx.DiGraph()
for edge in edge_list:
    expert_G.add_edge(edge[0], edge[1])

# nx_pydot requires the pydot package to be installed
dot_graph = nx.drawing.nx_pydot.to_pydot(expert_G)

model = CausalModel(
    # encode_data is the reporter's own helper; its definition is not included in the issue
    data=encode_data(graph_df[list(expert_G.nodes())]),
    treatment='treatment',
    outcome='outcome',
    graph=str(dot_graph)
)

# Fails with the KeyError shown below
independence_refutation_result = model.refute_graph()
```

Error:


```
KeyError                                  Traceback (most recent call last)
Cell In[124], line 29
     20 dot_graph = p = nx.drawing.nx_pydot.to_pydot(expert_G)
     22 model = CausalModel(
     23     data=encode_data(graph_df[list(expert_G.nodes())]),
     24     treatment='treatment',
     25     outcome='outcome',
     26     graph=str(dot_graph)
     27 )
---> 29 independence_refutation_result = model.refute_graph()

File ~/.cache/pypoetry/virtualenvs/causal-JHqzxB11-py3.8/lib/python3.8/site-packages/dowhy/causal_model.py:518, in CausalModel.refute_graph(self, k, independence_test, independence_constraints)
    514                 conditional_independences.append([a, b, k_list])
    516     independence_constraints = conditional_independences
--> 518 res = refuter.refute_model(independence_constraints=independence_constraints)
    520 self.logger.info(refuter._refutation_passed)
    522 return res

File ~/.cache/pypoetry/virtualenvs/causal-JHqzxB11-py3.8/lib/python3.8/site-packages/dowhy/causal_refuters/graph_refuter.py:108, in GraphRefuter.refute_model(self, independence_constraints)
    105 elif a in discrete_columns and b in discrete_columns and all(node in discrete_columns for node in c):
    106     # a, b and c are all discrete variables
    107     if self._method_name_discrete is None or self._method_name_discrete == "conditional_mutual_information":
--> 108         self.conditional_mutual_information(x=a, y=b, z=c)
    109     else:
    110         self.logger.error(
    111             "Invalid conditional independence test for discrete data. Supported tests - conditional_mutual_information"
    112         )

File ~/.cache/pypoetry/virtualenvs/causal-JHqzxB11-py3.8/lib/python3.8/site-packages/dowhy/causal_refuters/graph_refuter.py:59, in GraphRefuter.conditional_mutual_information(self, x, y, z)
     58 def conditional_mutual_information(self, x=None, y=None, z=None):
---> 59     cmi_val = conditional_MI(data=self._data, x=x, y=y, z=list(z))
     60     key = (x, y) + (z,)
     61     if cmi_val <= 0.05:

File ~/.cache/pypoetry/virtualenvs/causal-JHqzxB11-py3.8/lib/python3.8/site-packages/dowhy/utils/cit.py:143, in conditional_MI(data, x, y, z)
    133 def conditional_MI(data=None, x=None, y=None, z=None):
    134     """
    135     Method to return conditional mutual information between X and Y given Z
    136     I(X, Y | Z) = H(X|Z) - H(X|Y,Z)
   (...)
    141     :returns : conditional mutual information between X and Y given Z
    142     """
--> 143     X = data[list(x)].astype(int)
    144     Y = data[list(y)].astype(int)
    145     t = list(z)

File ~/.cache/pypoetry/virtualenvs/causal-JHqzxB11-py3.8/lib/python3.8/site-packages/pandas/core/frame.py:3813, in DataFrame.__getitem__(self, key)
   3811     if is_iterator(key):
   3812         key = list(key)
-> 3813     indexer = self.columns._get_indexer_strict(key, "columns")[1]
   3815 # take() does not accept boolean indexers
   3816 if getattr(indexer, "dtype", None) == bool:

File ~/.cache/pypoetry/virtualenvs/causal-JHqzxB11-py3.8/lib/python3.8/site-packages/pandas/core/indexes/base.py:6070, in Index._get_indexer_strict(self, key, axis_name)
   6067 else:
   6068     keyarr, indexer, new_indexer = self._reindex_non_unique(keyarr)
-> 6070 self._raise_if_missing(keyarr, indexer, axis_name)
   6072 keyarr = self.take(indexer)
   6073 if isinstance(key, Index):
   6074     # GH 42790 - Preserve name from an Index

File ~/.cache/pypoetry/virtualenvs/causal-JHqzxB11-py3.8/lib/python3.8/site-packages/pandas/core/indexes/base.py:6130, in Index._raise_if_missing(self, key, indexer, axis_name)
   6128     if use_interval_msg:
   6129         key = list(key)
-> 6130     raise KeyError(f"None of [{key}] are in the [{axis_name}]")
   6132 not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())
   6133 raise KeyError(f"{not_found} not in index")

KeyError: "None of [Index(['F', 'o', 'o'], dtype='object')] are in the [columns]"
```
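The same KeyError can be reproduced with pandas alone, which isolates the `data[list(x)]` pattern from DoWhy. This is a minimal illustrative sketch; the small DataFrame is made up for the demonstration:

```python
import pandas as pd

df = pd.DataFrame({"Foo": [0, 1, 2], "Bar": [1, 0, 1]})

# list() on a string iterates over its characters, not over column names
chars = list("Foo")
print(chars)  # ['F', 'o', 'o']

# Same pattern as data[list(x)] in conditional_MI: none of 'F', 'o', 'o' is a column
selection_failed = False
try:
    df[chars]
except KeyError as err:
    selection_failed = True
    print("KeyError:", err)

# Wrapping the single name in a list selects the intended column
print(df[["Foo"]].shape)  # (3, 1)
```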

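Cause: conditional_MI in dowhy/utils/cit.py evaluates data[list(x)], and list() applied to a single column name such as 'Foo' splits it into characters (['F', 'o', 'o']), none of which is a column of the data.

Suggested Fix:

GraphRefuter should type-check x, y and z and wrap a bare column name in a list before indexing, either in refute_model or in conditional_mutual_information.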
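The suggested fix could look roughly like the sketch below. It is not DoWhy's code: `as_column_list`, `joint_entropy` and `conditional_mi` are hypothetical names, and the estimator is a plain plug-in (empirical) entropy computation of I(X, Y | Z) = H(X|Z) - H(X|Y,Z), the identity quoted in the conditional_MI docstring above. The key step is normalizing each argument so a bare string such as 'Foo' becomes ['Foo'] rather than ['F', 'o', 'o']:

```python
import math
from collections import Counter
from collections.abc import Iterable
from typing import Hashable, List

import pandas as pd


def as_column_list(cols) -> List[Hashable]:
    """Hypothetical helper: None -> [], one column name -> [name], an iterable -> list."""
    if cols is None:
        return []
    # Strings are iterable, so test them first; list("Foo") would give ['F', 'o', 'o']
    if isinstance(cols, (str, bytes)) or not isinstance(cols, Iterable):
        return [cols]
    return list(cols)


def joint_entropy(data: pd.DataFrame, cols: List[Hashable]) -> float:
    """Plug-in Shannon entropy (natural log) of the joint distribution of discrete columns."""
    if not cols:
        return 0.0
    n = len(data)
    # Count each observed combination of values across the selected columns
    counts = Counter(data[cols].itertuples(index=False, name=None))
    return -sum((c / n) * math.log(c / n) for c in counts.values())


def conditional_mi(data: pd.DataFrame, x, y, z=None) -> float:
    """Hypothetical sketch of I(X, Y | Z) = H(X|Z) - H(X|Y,Z) with normalized column arguments."""
    xs, ys, zs = as_column_list(x), as_column_list(y), as_column_list(z)
    # H(X|Z) = H(X,Z) - H(Z) and H(X|Y,Z) = H(X,Y,Z) - H(Y,Z)
    h_x_given_z = joint_entropy(data, xs + zs) - joint_entropy(data, zs)
    h_x_given_yz = joint_entropy(data, xs + ys + zs) - joint_entropy(data, ys + zs)
    return h_x_given_z - h_x_given_yz


# Usage: bare strings and lists of names are both accepted
df = pd.DataFrame({
    "Foo": [0, 0, 1, 1, 0, 1, 0, 1],
    "Bar": [0, 0, 1, 1, 1, 0, 0, 1],
    "treatment": [0, 1, 0, 1, 0, 1, 0, 1],
})
print(conditional_mi(df, "Foo", "Bar", ["treatment"]))
print(conditional_mi(df, ["Foo"], "Bar", "treatment"))
```

Normalizing at the entry point keeps the indexing code unchanged for callers that already pass lists, while single names from refute_graph's constraint list are handled the same way.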