py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

plotting the graph in dowhy_example_effect_of_memberrewards_program.ipynb #493

Status: Open. zahs123 opened this issue 2 years ago.

zahs123 commented 2 years ago

Hi,

I tried to replicate the example in dowhy_example_effect_of_memberrewards_program.ipynb, but I don't get the same graph plot as in the example. I copied the code from the notebook, and this was my output: [screenshot of the rendered graph, not reproduced]

(code below)

# Creating some simulated data for our example
import pandas as pd
import numpy as np
num_users = 10000
num_months = 12

signup_months = np.random.choice(np.arange(1, num_months), num_users) * np.random.randint(0,2, size=num_users) # signup_months == 0 means customer did not sign up
df = pd.DataFrame({
    'user_id': np.repeat(np.arange(num_users), num_months),
    'signup_month': np.repeat(signup_months, num_months), # signup month == 0 means customer did not sign up
    'month': np.tile(np.arange(1, num_months+1), num_users), # months are from 1 to 12
    'spend': np.random.poisson(500, num_users*num_months)  # Poisson-distributed spend centered at 500
})
# A customer is in the treatment group if and only if they signed up
df["treatment"] = df["signup_month"]>0
# Simulating an effect of month (monotonically decreasing--customers buy less later in the year)
df["spend"] = df["spend"] - df["month"]*10
# Simulating a simple treatment effect of 100
after_signup = (df["signup_month"] < df["month"]) & (df["treatment"])
df.loc[after_signup, "spend"] += 100
df

import dowhy

# Setting the signup month (for ease of analysis)
i = 3

causal_graph = """digraph {
treatment[label="Program Signup in month i"];
pre_spends;
post_spends;
Z->treatment;
pre_spends -> treatment;
treatment->post_spends;
signup_month->post_spends;
signup_month->treatment;
}"""

# Post-process the data based on the graph and the month of the treatment (signup)
# For each customer, determine their average monthly spend before and after month i
df_i_signupmonth = (
    df[df.signup_month.isin([0, i])]
    .groupby(["user_id", "signup_month", "treatment"])
    .apply(
        lambda x: pd.Series(
            {
                "pre_spends": x.loc[x.month < i, "spend"].mean(),
                "post_spends": x.loc[x.month > i, "spend"].mean(),
            }
        )
    )
    .reset_index()
)
print(df_i_signupmonth)
model = dowhy.CausalModel(data=df_i_signupmonth,
                          graph=causal_graph.replace("\n", " "),
                          treatment="treatment",
                          outcome="post_spends")
model.view_model()
from IPython.display import Image, display
display(Image(filename="causal_model.png"))
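Because the simulation is fully specified, the noise-free pre/post spends for signup month i = 3 can be worked out by hand. This gives a sanity check on df_i_signupmonth that is independent of how the graph is rendered. This is a minimal pure-Python sketch; the constant names and helper functions are hypothetical and simply restate the simulation parameters above (Poisson mean 500, -10 per month, +100 in every month strictly after signup).

```python
# Hypothetical restatement of the simulation parameters used above.
BASE_MEAN = 500          # Poisson mean of monthly spend
MONTH_EFFECT = -10       # spend drops by 10 per month
TREATMENT_EFFECT = 100   # added in every month strictly after signup
NUM_MONTHS = 12

def expected_spend(month, signup_month):
    """Noise-free spend in `month` for a customer who signed up in `signup_month` (0 = never)."""
    spend = BASE_MEAN + MONTH_EFFECT * month
    if signup_month > 0 and signup_month < month:
        spend += TREATMENT_EFFECT
    return spend

def expected_pre_post(i, signup_month):
    """Average expected spend before and after month i, mirroring pre_spends/post_spends."""
    pre = [expected_spend(m, signup_month) for m in range(1, i)]
    post = [expected_spend(m, signup_month) for m in range(i + 1, NUM_MONTHS + 1)]
    return sum(pre) / len(pre), sum(post) / len(post)

i = 3
treated = expected_pre_post(i, signup_month=i)
control = expected_pre_post(i, signup_month=0)
print("treated (pre, post):", treated)                      # (485.0, 520.0)
print("control (pre, post):", control)                      # (485.0, 420.0)
print("post_spends difference:", treated[1] - control[1])   # 100.0
```

Both groups share the same expected pre_spends, so the gap in post_spends recovers the simulated treatment effect of 100; the sample means from the actual data will scatter around these values because of the Poisson noise.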
amit-sharma commented 2 years ago

The graph structure is the same as in the member rewards notebook.

The rendering is different because DoWhy falls back to matplotlib when pygraphviz is not found. Can you try installing pygraphviz (pip install pygraphviz, which requires graphviz to be installed first) and then rerunning the example? Instructions for installing graphviz on your system are in the README.
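To confirm which rendering path applies, one option is to check whether pygraphviz is importable from the same environment the notebook runs in. This is a minimal sketch using only the standard library; the helper name is hypothetical, it does not call any DoWhy internals, and DoWhy's own fallback logic may check more than this.

```python
import importlib.util

def pygraphviz_available() -> bool:
    """Return True if the pygraphviz package can be found in this environment."""
    return importlib.util.find_spec("pygraphviz") is not None

if pygraphviz_available():
    print("pygraphviz found: graphviz-style rendering should be possible")
else:
    print("pygraphviz not found: expect the matplotlib fallback rendering")
```

If this reports that pygraphviz is missing even after installing it, the package was likely installed into a different Python environment than the one the notebook kernel uses.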

zahs123 commented 2 years ago

Thanks for this. I'm struggling to install graphviz and pygraphviz the way specified, and even installing them my own way doesn't work.

zahs123 commented 2 years ago

What is 'Z' in the graph? Is it meant to represent a confounding variable?

amit-sharma commented 2 years ago

Z represents an instrumental variable, as shown by the output of the identify_effect method. But it is not observed in the data, so we do not use it.
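One way to see that Z cannot be used is to compare the nodes in the graph string against the columns of df_i_signupmonth. This is a minimal sketch, not DoWhy's own logic: parse_edges is a hypothetical helper that only handles simple "a -> b;" lines, and the column set is written out by hand, assuming it matches the output of the post-processing step above.

```python
import re

causal_graph = """digraph {
treatment[label="Program Signup in month i"];
pre_spends;
post_spends;
Z->treatment;
pre_spends -> treatment;
treatment->post_spends;
signup_month->post_spends;
signup_month->treatment;
}"""

def parse_edges(dot):
    """Hypothetical helper: extract (source, target) pairs from simple 'a -> b' DOT edges."""
    return re.findall(r"(\w+)\s*->\s*(\w+)", dot)

edges = parse_edges(causal_graph)
nodes = {n for edge in edges for n in edge}

# Assumed columns of df_i_signupmonth after reset_index()
data_columns = {"user_id", "signup_month", "treatment", "pre_spends", "post_spends"}

unobserved = nodes - data_columns
print("unobserved nodes:", unobserved)   # {'Z'}

# Z's only edge goes into the treatment, with no direct edge to post_spends,
# which matches the usual graphical description of an instrument.
print([e for e in edges if "Z" in e])    # [('Z', 'treatment')]
```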