py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

Only one CATE value is recorded from EconML. #890

Closed yiwei-ang closed 1 year ago

yiwei-ang commented 1 year ago

Describe the bug

In dowhy >= 0.9, fitting the DML estimator from EconML produces only one CATE value.

Steps to reproduce the behavior

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import graphviz
import dowhy
from IPython.display import Image, display

from dowhy import CausalModel
import dowhy.datasets

data = dowhy.datasets.linear_dataset(beta=10,
        num_common_causes=5,
        num_samples=5000,
        treatment_is_binary=True,
        stddev_treatment_noise=10,
        num_discrete_common_causes=1)
df = data["df"]
print(f"True ATE: {data['ate']}")

model = CausalModel(data=data['df'],
                    treatment=data['treatment_name'],
                    outcome=data['outcome_name'],
                    graph=data['gml_graph'])

identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LassoCV, LogisticRegressionCV
from sklearn.ensemble import GradientBoostingRegressor
dml_estimate = model.estimate_effect(
    identified_estimand,
    method_name="backdoor.econml.dml.DML",
    control_value=0,
    treatment_value=1,
    confidence_intervals=False,
    method_params={
        "init_params": {
            "model_y": GradientBoostingRegressor(),
            "model_t": GradientBoostingRegressor(),
            "model_final": LassoCV(),
            "featurizer": PolynomialFeatures(degree=1, include_bias=True),
        },
        "fit_params": {},
    },
)
print(dml_estimate)

## Estimate
Mean value: 10.035988054924891
Effect estimates: [[10.03598805]]

Expected behavior

dml_estimate.cate_estimates should contain 5000 rows of CATE values, as it did in dowhy < 0.9. In dowhy >= 0.9, it contains only the mean value.

Version information:

Additional context

I'm not sure whether EconML needs to be updated to accommodate the latest changes in dowhy. I ask because running est_dw = dml_estimate.estimator.estimator.dowhy produces the following warning:

econml has not been tested with dowhy versions >= 0.9

amit-sharma commented 1 year ago

@yiwei-ang This is because no effect modifiers are specified in the dataset. If you change your dataset to include effect modifiers (using num_effect_modifiers), you will start seeing multiple CATE values. Alternatively, you can provide a graph as input and DoWhy can infer the effect modifiers.

data = dowhy.datasets.linear_dataset(beta=10,
        num_common_causes=5,
        num_effect_modifiers=1,
        num_samples=5000,
        treatment_is_binary=True,
        stddev_treatment_noise=10,
        num_discrete_common_causes=1)
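
To illustrate why a dataset without effect modifiers yields a single CATE value, here is a minimal NumPy sketch. It is not EconML's DML implementation. It assumes a randomized binary treatment and replaces the final stage with a plain least-squares fit in which the effect is linear in the features [1, x1, x2, ...]. The helper fit_cate, the variable names, and the simulated coefficients (10 and 2) are hypothetical and chosen only for illustration.

```python
import numpy as np


def fit_cate(y, t, modifiers):
    """Hypothetical helper: least-squares fit of y on [phi(X), t * phi(X)].

    phi(X) is [1, x1, x2, ...]. Returns the fitted effect phi(X) @ theta,
    one value per row.
    """
    n = len(y)
    # With no modifiers, phi is just a column of ones (intercept only).
    phi = np.column_stack([np.ones(n), *modifiers])
    # Baseline terms followed by treatment-interaction terms.
    design = np.column_stack([phi, phi * t[:, None]])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    theta = coef[phi.shape[1]:]  # coefficients on the interaction terms
    return phi @ theta


# Simulated data (assumption): the true effect is 10 + 2 * x.
rng = np.random.default_rng(0)
n = 5000
x = rng.normal(size=n)
t = rng.integers(0, 2, size=n).astype(float)
y = (10.0 + 2.0 * x) * t + rng.normal(size=n)

cate_const = fit_cate(y, t, [])    # no effect modifiers
cate_hetero = fit_cate(y, t, [x])  # one effect modifier

print("distinct values without modifiers:", np.unique(cate_const).size)
print("distinct values with one modifier:", np.unique(cate_hetero).size)
```

Without modifiers, the only feature is the constant, so every row receives the same fitted effect, which is what the single value in the report reflects. With a modifier, the fitted effect varies row by row.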