py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

Effect modifiers are not propagated to estimate_effect #990

Closed kbattocchi closed 1 year ago

kbattocchi commented 1 year ago

Describe the bug

The docs for estimate_effect indicate that if effect_modifiers=None (the default), the effect modifiers from the CausalModel are used. This doesn't appear to be the case: effect_modifiers=None is instead treated as if effect_modifiers=[] had been passed.
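To make the documented semantics concrete, here is a minimal sketch of the fallback described in the docs. The helper name `resolve_effect_modifiers` and its arguments are hypothetical and purely illustrative; this is not DoWhy's actual code. It only spells out the distinction between `None` ("use the model's modifiers") and `[]` ("use no modifiers").

```python
from typing import List, Optional, Sequence


def resolve_effect_modifiers(
    requested: Optional[Sequence[str]],
    model_default: Sequence[str],
) -> List[str]:
    """Hypothetical helper (not part of DoWhy) showing the documented fallback."""
    # None means "nothing was passed": fall back to the model's effect modifiers.
    if requested is None:
        return list(model_default)
    # An explicit list, including the empty list, is taken as-is.
    return list(requested)


model_modifiers = ["X"]
print(resolve_effect_modifiers(None, model_modifiers))   # ['X']
print(resolve_effect_modifiers(["X"], model_modifiers))  # ['X']
print(resolve_effect_modifiers([], model_modifiers))     # []
```

The behavior reported here corresponds to the first call returning `[]` rather than `['X']`.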

Steps to reproduce the behavior

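import numpy as np
import pandas as pd
from dowhy import CausalModel

# Synthetic data: outcome Y, treatment T, effect modifier X, extra confounders W0 and W1
np.random.seed(123)  # fixed seed so the data is the same on every run
arr = np.random.normal(size=(500, 5))
df = pd.DataFrame(arr, columns=['Y', 'T', 'X', 'W0', 'W1'])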

model = CausalModel(
    data = df,
    treatment = 'T',
    outcome = 'Y',
    effect_modifiers = ['X'],
    common_causes = ['X', 'W0', 'W1'],
    estimand_type="nonparametric-ate"
)

estimand = model.identify_effect()
est1 = model.estimate_effect(identified_estimand=estimand,
                             method_name="backdoor.econml.dml.LinearDML",
                             method_params={"init_params": {"random_state":123},
                                            "fit_params": {}})

est2 = model.estimate_effect(identified_estimand=estimand,
                             effect_modifiers=['X'],
                             method_name="backdoor.econml.dml.LinearDML",
                             method_params={"init_params": {"random_state":123},
                                            "fit_params": {}})

est3 = model.estimate_effect(identified_estimand=estimand,
                             effect_modifiers=[],
                             method_name="backdoor.econml.dml.LinearDML",
                             method_params={"init_params": {"random_state":123},
                                            "fit_params": {}})

print(est1.cate_estimates, est2.cate_estimates, est3.cate_estimates)

Expected behavior

The estimates from est1 and est2 should be the same (the CATE estimates conditional on the X column), while est3 should be different (the ATE, conditional on no variables).

Actual behavior

The estimates from est1 and est3 are the same instead.
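Rather than comparing the printed arrays by eye, the expected and actual behavior can be checked programmatically. The sketch below is one way to do that in plain Python; `flatten` and `estimates_match` are hypothetical helper names, and the tolerances are arbitrary choices. It assumes `cate_estimates` is an array-like of numbers, as the printed output in the reproduction suggests.

```python
import math


def flatten(values):
    """Yield floats from arbitrarily nested lists, tuples, or arrays."""
    try:
        items = iter(values)
    except TypeError:
        # Not iterable, so treat it as a single scalar.
        yield float(values)
        return
    for item in items:
        yield from flatten(item)


def estimates_match(a, b, rel_tol=1e-6, abs_tol=1e-9):
    """Return True if a and b hold the same numbers, up to tolerance."""
    xs, ys = list(flatten(a)), list(flatten(b))
    if len(xs) != len(ys):
        return False
    return all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol)
        for x, y in zip(xs, ys)
    )


# Usage with the objects from the reproduction script above:
#   estimates_match(est1.cate_estimates, est2.cate_estimates)  # expected: True
#   estimates_match(est1.cate_estimates, est3.cate_estimates)  # expected: False
```

With the bug present, the two calls would give the opposite results: est1 matches est3 and differs from est2.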

Version information:

kbattocchi commented 1 year ago

Interestingly, removing 'X' from common_causes results in all three estimates being the same (and they appear to be CATEs, not ATEs).

amit-sharma commented 1 year ago

Fixed via #988