py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

falsify_graph #1216

Open nicolavizioli opened 3 days ago

nicolavizioli commented 3 days ago

Hi all, I'm trying to use the SCM module, in particular falsify_graph.

The code (I'm using Streamlit) is the following:

  auto_assignment_summary = gcm.auto.assign_causal_mechanisms(
      scm, data_pd, override_models=True, quality=gcm.auto.AssignmentQuality.GOOD
  )
  gcm.fit(scm, data_pd)

  if st.session_state["choice"] == "lingam":  # causal graph chosen: LiNGAM
      result = falsify_graph((st.session_state["lingam_graph_nx"]), data_pd, plot_histogram=False, suggestions=True, n_permutations=100)
      st.text(result)
  else:  # causal graph chosen: PC
      result = falsify_graph((st.session_state["pc_graph_nx"]), data_pd, plot_histogram=False, suggestions=True, n_permutations=100)
      st.text(result)

where "lingam_graph_nx" and "pc_graph_nx" are causal graphs stored as networkx graphs, and data_pd is a pandas DataFrame.

After fitting the GCM, I get the following error:

AssertionError: 0 must be list, set or str. Got <class 'int'> instead!

Traceback:

  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\streamlit\runtime\scriptrunner\script_runner.py", line 600, in _run_script
    exec(code, module.__dict__)
  File "C:\Users\11612880\Desktop\Causal-Discovery-PoC\pages\9📊_ SCM.py", line 144, in <module>
    result = falsify_graph((st.session_state["pc_graph_nx"]), data_pd, plot_histogram=False, suggestions=True, n_permutations=100)
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 618, in falsify_graph
    summary_given = run_validations(
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 412, in run_validations
    m_summary = m(causal_graph=causal_graph)
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 138, in validate_lmc
    if not (node, non_desc, parents) in p_values_memory:
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 89, in __contains__
    X, Y = (_to_frozenset(i) for i in item[:2])
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 89, in <genexpr>
    X, Y = (_to_frozenset(i) for i in item[:2])
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 979, in _to_frozenset
    assert (
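The failing frame is _to_frozenset in dowhy/gcm/falsify.py. Judging only from the error message, it accepts a list, a set or a string, so a single node name that is an int gets rejected. The function below is a hypothetical re-creation of that check (to_frozenset_sketch is an invented name, and DoWhy's actual code may differ); it produces the same message for an integer and accepts a string:

```python
def to_frozenset_sketch(x):
    # Hypothetical stand-in for dowhy.gcm.falsify._to_frozenset, inferred only
    # from the error message in the traceback; DoWhy's real code may differ.
    # Only a list, a set or a single str name is accepted.
    assert isinstance(x, (list, set, str)), (
        f"{x} must be list, set or str. Got {type(x)} instead!"
    )
    # A single str name becomes a one-element frozenset; a list or set of
    # names is converted as a whole.
    return frozenset([x]) if isinstance(x, str) else frozenset(x)


print(to_frozenset_sketch("0"))  # frozenset({'0'})

try:
    to_frozenset_sketch(0)  # an integer node name, as in the thread
except AssertionError as err:
    print(err)  # 0 must be list, set or str. Got <class 'int'> instead!
```

In validate_lmc the first element passed in is a single node taken from the graph, which is why integer node labels trigger the assertion even when the rest of the pipeline runs fine.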


bloebp commented 3 days ago

Might be an issue with the variable names in the DataFrame. Can you try the following after loading/creating the data?

  data_pd.columns = [str(col) for col in data_pd.columns]
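A minimal sketch of that suggestion on a toy DataFrame (the values are made up for illustration). DataFrame.rename(columns=str), which applies str to every column label, is shown as an equivalent alternative to the list comprehension:

```python
import pandas as pd

# Toy data with default integer column labels, as you get from a list or array.
data_pd = pd.DataFrame([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(data_pd.columns.tolist())  # [0, 1, 2]

# Suggested fix: turn every column label into a string.
data_pd.columns = [str(col) for col in data_pd.columns]
print(data_pd.columns.tolist())  # ['0', '1', '2']

# Equivalent: rename accepts a function that is applied to each column label.
same = pd.DataFrame([[1.0, 2.0, 3.0]]).rename(columns=str)
print(same.columns.tolist())  # ['0', '1', '2']
```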

nicolavizioli commented 2 days ago

Thanks, I tried that, but now I get this error: AssertionError: 2 must be list, set or str. Got <class 'int'> instead!

The variable names in the DataFrame are now: Index(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], dtype='object'), of type <class 'pandas.core.indexes.base.Index'>

so I'm not sure this is the problem. Many thanks!
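The columns are indeed strings at this point. The node names checked in the traceback, however, come from the graph (validate_lmc iterates over its nodes), and the int 0 and the str '0' are different names. A small hypothetical helper (name_mismatches is invented here, not part of DoWhy) that compares the graph's node names with the DataFrame's column names makes such a mismatch visible:

```python
def name_mismatches(graph_nodes, column_names):
    """Hypothetical diagnostic, not a DoWhy function.

    Reports graph nodes without a column of exactly the same name and columns
    without a matching node. Comparison is exact, so 0 and '0' differ.
    """
    nodes, columns = set(graph_nodes), set(column_names)
    return {
        # Sorted by repr so that mixed int/str names can still be ordered.
        "nodes_missing_from_data": sorted(nodes - columns, key=repr),
        "columns_missing_from_graph": sorted(columns - nodes, key=repr),
    }


# Integer graph nodes vs. string column labels: nothing matches.
print(name_mismatches([0, 1, 2], ["0", "1", "2"]))
# {'nodes_missing_from_data': [0, 1, 2], 'columns_missing_from_graph': ['0', '1', '2']}

# Once both sides use strings, both lists are empty.
print(name_mismatches(["0", "1", "2"], ["0", "1", "2"]))
# {'nodes_missing_from_data': [], 'columns_missing_from_graph': []}
```

With the objects from this thread, the call would be name_mismatches(st.session_state["pc_graph_nx"].nodes, data_pd.columns).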

bloebp commented 2 days ago

Ok, do you think you can provide a small example snippet to reproduce this (you probably don't need the streamlit context here)? I can take a closer look then.

nicolavizioli commented 2 days ago

OK, thanks.

I have a networkx graph (the nodes are ints), from which I create the SCM:

  scm = gcm.StructuralCausalModel(st.session_state["pc_graph_nx"])

Then, to assign the causal mechanisms, I run:

  data_pd = pd.DataFrame(data_agg)  # DataFrame of my data
  data_pd.columns = [i for i in range(len(data_pd.columns))]  # define the columns as integers
  data_pd = check_and_convert_categorical(data_pd, threshold=2)  # my own function to detect categorical variables
  auto_assignment_summary = gcm.auto.assign_causal_mechanisms(
      scm, data_pd, override_models=True, quality=gcm.auto.AssignmentQuality.GOOD
  )
  gcm.fit(scm, data_pd)
  st.text(auto_assignment_summary)  # no problem up to this point

  def convert_nodes_to_str(nx_graph):  # convert nodes to strings
      mapping = {node: str(node) for node in nx_graph.nodes}
      nx_graph = nx.relabel_nodes(nx_graph, mapping)
      return nx_graph

  data_pd.columns = [str(col) for col in data_pd.columns]  # as suggested

and finally the code (in streamlit):

  pc_graph_nx=convert_nodes_to_str(st.session_state["pc_graph_nx"])
  print(pc_graph_nx.nodes)
  result = falsify_graph((st.session_state["pc_graph_nx"]), data_pd, plot_histogram=False, suggestions=True, n_permutations=100)
  st.text(result)

with the error:

AssertionError: 1 must be list, set or str. Got <class 'int'> instead!

Traceback:

  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\streamlit\runtime\scriptrunner\script_runner.py", line 600, in _run_script
    exec(code, module.__dict__)
  File "C:\Users\11612880\Desktop\Causal-Discovery-PoC\pages\9📊_ SCM.py", line 167, in <module>
    result = falsify_graph((st.session_state["pc_graph_nx"]), data_pd, plot_histogram=False, suggestions=True, n_permutations=100)
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 618, in falsify_graph
    summary_given = run_validations(
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 412, in run_validations
    m_summary = m(causal_graph=causal_graph)
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 138, in validate_lmc
    if not (node, non_desc, parents) in p_values_memory:
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 89, in __contains__
    X, Y = (_to_frozenset(i) for i in item[:2])
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 89, in <genexpr>
    X, Y = (_to_frozenset(i) for i in item[:2])
  File "C:\Users\11612880\Anaconda3\envs\causal\lib\site-packages\dowhy\gcm\falsify.py", line 979, in _to_frozenset
    assert (

Any idea? Thanks.
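One detail in the snippet above: convert_nodes_to_str returns a relabeled copy (nx.relabel_nodes copies the graph by default), but falsify_graph is still called with st.session_state["pc_graph_nx"], the original graph whose nodes are integers. A small sketch on a toy three-node graph (invented for illustration) shows that the original object keeps its integer labels:

```python
import networkx as nx


def convert_nodes_to_str(nx_graph):
    # Same helper as in the thread: relabel every node with its string form.
    mapping = {node: str(node) for node in nx_graph.nodes}
    # relabel_nodes returns a new graph by default (copy=True),
    # so the graph passed in is left unchanged.
    return nx.relabel_nodes(nx_graph, mapping)


original = nx.DiGraph([(0, 1), (1, 2)])  # toy graph with integer nodes
converted = convert_nodes_to_str(original)

print(list(original.nodes))   # [0, 1, 2]  (still integers)
print(list(converted.nodes))  # ['0', '1', '2']

# The converted graph is the one whose node names match string columns,
# e.g. falsify_graph(converted, data_pd, ...)
```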

bloebp commented 2 days ago

The column conversion needs to happen before calling the assignment and fit; otherwise the variables are expected to be integers but are strings. I had similar issues before due to integer column names. So maybe just change:

  data_pd.columns = [i for i in range(len(data_pd.columns))]  # define the columns as integers

to

  data_pd.columns = [str(i) for i in range(len(data_pd.columns))]

But you also need to make sure that the node names in the graph are strings of integers (e.g. '0') instead of raw integers, so that they match the column names.
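Putting both requirements together, a minimal sketch with toy inputs (the graph and values are invented for illustration) that makes the column names and the node names strings before anything is handed to DoWhy, then checks that they match. The DoWhy steps from this thread appear only as comments and are not executed here:

```python
import networkx as nx
import pandas as pd

# Toy inputs with integer labels, mimicking the situation in the thread.
graph = nx.DiGraph([(0, 1), (1, 2)])
data_pd = pd.DataFrame({0: [0.1, 0.4, 0.9], 1: [1.2, 0.8, 1.9], 2: [2.0, 2.5, 3.1]})

# 1. String column names, set BEFORE the SCM is built, assigned and fitted.
data_pd.columns = [str(i) for i in range(len(data_pd.columns))]

# 2. Matching string node names in the graph (keep the returned copy).
graph = nx.relabel_nodes(graph, {node: str(node) for node in graph.nodes})

# 3. Sanity check: every node has a column with exactly the same name.
assert set(graph.nodes) == set(data_pd.columns), "node/column names differ"

# The steps from the thread would follow, in this order (not run here):
#   scm = gcm.StructuralCausalModel(graph)
#   gcm.auto.assign_causal_mechanisms(scm, data_pd, ...)
#   gcm.fit(scm, data_pd)
#   falsify_graph(graph, data_pd, ...)
print(sorted(graph.nodes), data_pd.columns.tolist())
```

The point of doing the renaming first is that every later step (mechanism assignment, fitting, falsification) then sees one consistent set of string names.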

nicolavizioli commented 2 days ago

Now it works. Thanks, much appreciated!