py-why / EconML

ALICE (Automated Learning and Intelligence for Causation and Economics) is a Microsoft Research project aimed at applying Artificial Intelligence concepts to economic decision making. One of its goals is to build a toolkit that combines state-of-the-art machine learning techniques with econometrics in order to bring automation to complex causal inference problems. To date, the ALICE Python SDK (econml) implements orthogonal machine learning algorithms such as the double machine learning work of Chernozhukov et al. This toolkit is designed to measure the causal effect of some treatment variable(s) t on an outcome variable y, controlling for a set of features x.
https://www.microsoft.com/en-us/research/project/alice/

Obtaining [T0][T1] or [Y0][T0] in causalforest shap values #708

Closed fobembe closed 1 year ago

fobembe commented 1 year ago

I am trying to produce SHAP plots to identify the relevant factors for my predictive analytics. I have done this with causalml without any problem, but I have been struggling to get it working with econml. Does anyone have a clue what I am doing wrong in the following code?

I'll appreciate any help. Thanks.

image

vsyrgkanis commented 1 year ago

Most probably you have a binary treatment, in which case you need to use the value of the treatment itself as part of the key. You can print the keys of the dictionary shap_values['Y0'] to see what the right second key is instead of T0.

vsyrgkanis commented 1 year ago

Check cells 6,7,8 here https://github.com/microsoft/EconML/blob/main/notebooks/Interpretability%20with%20SHAP.ipynb

I think you need to pass T0_1 if your treatments were taking values 0/1 and you are using the DRLearner or the metalearners.

fobembe commented 1 year ago

Thanks a lot for your reply. I am using a causal forest and tried the following, but I still get an error:

image

fobembe commented 1 year ago

image image

kbattocchi commented 1 year ago

It would be very helpful to get a complete, self-contained repro of your issue (e.g. in the last screenshot above it looks like the last three outputs result from evaluations 30, 37, and 49 in the notebook, but without knowing what's happening in between those calls it's very hard to diagnose the problem).

If you have continuous treatments, then I believe that the dictionary returned by est.shap_values(X_train) should always have a single key in the outer dictionary for each outcome, and indexing into this dictionary should give you an inner dictionary with one key per treatment; the names of those keys will depend on whether the outcomes and treatments already have names (for example if you passed in a pandas Series during fitting we will take the name from that) or if we just assign default names.
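To check which keys are really there, something like the following could be used. This is only a sketch: the helper name list_shap_keys is hypothetical, and a plain dictionary with illustrative names ("Y0", "T0_1") stands in for the result of est.shap_values(X_train).

```python
def list_shap_keys(shap_values):
    """Return the (outcome, treatment) key pairs of a nested SHAP result dict."""
    return [
        (outcome, treatment)
        for outcome, per_treatment in shap_values.items()
        for treatment in per_treatment
    ]


# Stand-in for est.shap_values(X_train); the real inner values would be
# SHAP explanation objects rather than strings.
example = {"Y0": {"T0_1": "<explanation>"}}

for outcome, treatment in list_shap_keys(example):
    print(f"shap_values[{outcome!r}][{treatment!r}]")
```

Iterating over the dictionary only reads the keys that already exist, so it does not depend on guessing the names in advance.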

What are the dimensions of your Y, T, and X arrays? Also, can you print the output of pip list just to make sure that we understand what environment you're running in?

fobembe commented 1 year ago

I really appreciate this. My treatment is binary 0/1 and my outcome is continuous (Sales). I reran the whole code to provide a complete picture of what is happening. Thanks again for your assistance.

image image image image

fobembe commented 1 year ago

Output of pip list (package name followed by version):


absl-py 1.3.0 aeppl 0.0.33 aesara 2.7.9 aiohttp 3.8.3 aiosignal 1.3.1 alabaster 0.7.12 albumentations 1.2.1 altair 4.2.0 appdirs 1.4.4 arviz 0.12.1 astor 0.8.1 astropy 4.3.1 astunparse 1.6.3 async-timeout 4.0.2 atari-py 0.2.9 atomicwrites 1.4.1 attrs 22.1.0 audioread 3.0.0 autograd 1.5 Babel 2.11.0 backcall 0.2.0 beautifulsoup4 4.6.3 bleach 5.0.1 blis 0.7.9 bokeh 2.3.3 branca 0.6.0 bs4 0.0.1 CacheControl 0.12.11 cachetools 5.2.0 catalogue 2.0.8 certifi 2022.9.24 cffi 1.15.1 cftime 1.6.2 chardet 3.0.4 charset-normalizer 2.1.1 click 7.1.2 clikit 0.6.2 cloudpickle 1.5.0 cmake 3.22.6 cmdstanpy 1.0.8 colorcet 3.0.1 colorlover 0.3.0 community 1.0.0b1 confection 0.0.3 cons 0.4.5 contextlib2 0.5.5 convertdate 2.4.0 crashtest 0.3.1 crcmod 1.7 cufflinks 0.17.3 cvxopt 1.3.0 cvxpy 1.2.2 cycler 0.11.0 cymem 2.0.7 Cython 0.29.32 daft 0.0.4 dask 2022.2.1 datascience 0.17.5 db-dtypes 1.0.4 debugpy 1.0.0 decorator 4.4.2 defusedxml 0.7.1 descartes 1.1.0 dill 0.3.6 distributed 2022.2.1 dlib 19.24.0 dm-tree 0.1.7 dnspython 2.2.1 docutils 0.17.1 dopamine-rl 1.0.5 earthengine-api 0.1.334 easydict 1.10 econml 0.14.0 ecos 2.0.10 editdistance 0.5.3 en-core-web-sm 3.4.1 entrypoints 0.4 ephem 4.1.3 et-xmlfile 1.1.0 etils 0.9.0 etuples 0.3.8 fa2 0.3.5 fastai 2.7.10 fastcore 1.5.27 fastdownload 0.0.7 fastdtw 0.3.4 fastjsonschema 2.16.2 fastprogress 1.0.3 fastrlock 0.8.1 feather-format 0.4.1 filelock 3.8.0 firebase-admin 5.3.0 fix-yahoo-finance 0.0.22 Flask 1.1.4 flatbuffers 1.12 folium 0.12.1.post1 frozenlist 1.3.3 fsspec 2022.11.0 future 0.16.0 gast 0.4.0 GDAL 2.2.2 gdown 4.4.0 gensim 3.6.0 geographiclib 1.52 geopy 1.17.0 gin-config 0.5.0 glob2 0.7 google 2.0.3 google-api-core 2.8.2 google-api-python-client 1.12.11 google-auth 2.15.0 google-auth-httplib2 0.0.4 google-auth-oauthlib 0.4.6 google-cloud-bigquery 3.3.6 google-cloud-bigquery-storage 2.16.2 google-cloud-core 2.3.2 google-cloud-datastore 2.9.0 google-cloud-firestore 2.7.2 google-cloud-language 2.6.1 google-cloud-storage 2.5.0 
google-cloud-translate 3.8.4 google-colab 1.0.0 google-crc32c 1.5.0 google-pasta 0.2.0 google-resumable-media 2.4.0 googleapis-common-protos 1.57.0 googledrivedownloader 0.4 graphviz 0.10.1 greenlet 2.0.1 grpcio 1.51.1 grpcio-status 1.48.2 gspread 3.4.2 gspread-dataframe 3.0.8 gym 0.25.2 gym-notices 0.0.8 h5py 3.1.0 HeapDict 1.0.1 hijri-converter 2.2.4 holidays 0.17.2 holoviews 1.14.9 html5lib 1.0.1 httpimport 0.5.18 httplib2 0.17.4 httpstan 4.6.1 humanize 0.5.1 hyperopt 0.1.2 idna 2.10 imageio 2.9.0 imagesize 1.4.1 imbalanced-learn 0.8.1 imblearn 0.0 imgaug 0.4.0 importlib-metadata 4.13.0 importlib-resources 5.10.0 imutils 0.5.4 inflect 2.1.0 intel-openmp 2022.2.1 intervaltree 2.1.0 ipykernel 5.3.4 ipython 7.9.0 ipython-genutils 0.2.0 ipython-sql 0.3.9 ipywidgets 7.7.1 itsdangerous 1.1.0 jax 0.3.25 jaxlib 0.3.25+cuda11.cudnn805 jieba 0.42.1 Jinja2 2.11.3 joblib 1.2.0 jpeg4py 0.1.4 jsonschema 4.3.3 jupyter-client 6.1.12 jupyter-console 6.1.0 jupyter-core 5.1.0 jupyterlab-widgets 3.0.3 kaggle 1.5.12 kapre 0.3.7 keras 2.9.0 Keras-Preprocessing 1.1.2 keras-vis 0.4.1 kiwisolver 1.4.4 korean-lunar-calendar 0.3.1 langcodes 3.3.0 libclang 14.0.6 librosa 0.8.1 lightgbm 2.2.3 llvmlite 0.39.1 lmdb 0.99 locket 1.0.0 logical-unification 0.4.5 LunarCalendar 0.0.9 lxml 4.9.1 Markdown 3.4.1 MarkupSafe 2.0.1 marshmallow 3.19.0 matplotlib 3.2.2 matplotlib-venn 0.11.7 miniKanren 1.0.3 missingno 0.5.1 mistune 0.8.4 mizani 0.7.3 mkl 2019.0 mlxtend 0.14.0 more-itertools 9.0.0 moviepy 0.2.3.5 mpmath 1.2.1 msgpack 1.0.4 multidict 6.0.3 multipledispatch 0.6.0 multitasking 0.0.11 murmurhash 1.0.9 music21 5.5.0 natsort 5.5.0 nbconvert 5.6.1 nbformat 5.7.0 netCDF4 1.6.2 networkx 2.8.8 nibabel 3.0.2 nltk 3.7 notebook 5.7.16 numba 0.56.4 numexpr 2.8.4 numpy 1.21.6 oauth2client 4.1.3 oauthlib 3.2.2 okgrade 0.4.3 opencv-contrib-python 4.6.0.66 opencv-python 4.6.0.66 opencv-python-headless 4.6.0.66 openpyxl 3.0.10 opt-einsum 3.3.0 osqp 0.6.2.post0 packaging 21.3 palettable 3.3.0 pandas 1.3.5 
pandas-datareader 0.9.0 pandas-gbq 0.17.9 pandas-profiling 1.4.1 pandocfilters 1.5.0 panel 0.12.1 param 1.12.2 parso 0.8.3 partd 1.3.0 pastel 0.2.1 pathlib 1.0.1 pathy 0.10.0 patsy 0.5.3 pep517 0.13.0 pexpect 4.8.0 pickleshare 0.7.5 Pillow 7.1.2 pip 21.1.3 pip-tools 6.2.0 platformdirs 2.5.4 plotly 5.5.0 plotnine 0.8.0 pluggy 0.7.1 pooch 1.6.0 portpicker 1.3.9 prefetch-generator 1.0.3 preshed 3.0.8 prettytable 3.5.0 progressbar2 3.38.0 prometheus-client 0.15.0 promise 2.3 prompt-toolkit 2.0.10 prophet 1.1.1 proto-plus 1.22.1 protobuf 3.19.6 psutil 5.4.8 psycopg2 2.9.5 ptyprocess 0.7.0 py 1.11.0 pyarrow 9.0.0 pyasn1 0.4.8 pyasn1-modules 0.2.8 pycocotools 2.0.6 pycparser 2.21 pyct 0.4.8 pydantic 1.10.2 pydata-google-auth 1.4.0 pydot 1.3.0 pydot-ng 2.0.0 pydotplus 2.0.2 PyDrive 1.3.1 pyemd 0.5.1 pyerfa 2.0.0.1 Pygments 2.6.1 pygobject 3.26.1 pylev 1.4.0 pymc 4.1.4 PyMeeus 0.5.11 pymongo 4.3.3 pymystem3 0.2.0 PyOpenGL 3.1.6 pyparsing 3.0.9 pyrsistent 0.19.2 pysimdjson 3.2.0 pysndfile 1.3.8 PySocks 1.7.1 pystan 3.3.0 pytest 3.6.4 python-apt 0.0.0 python-dateutil 2.8.2 python-louvain 0.16 python-slugify 7.0.0 python-utils 3.4.5 pytz 2022.6 pyviz-comms 2.2.1 PyWavelets 1.4.1 PyYAML 6.0 pyzmq 23.2.1 qdldl 0.1.5.post2 qudida 0.0.4 regex 2022.6.2 requests 2.23.0 requests-oauthlib 1.3.1 resampy 0.4.2 rpy2 3.5.5 rsa 4.9 scikit-image 0.18.3 scikit-learn 1.0.2 scipy 1.7.3 screen-resolution-extra 0.0.0 scs 3.2.2 seaborn 0.11.2 Send2Trash 1.8.0 setuptools 57.4.0 setuptools-git 1.2 shap 0.40.0 Shapely 1.8.5.post1 six 1.15.0 sklearn-pandas 1.8.0 slicer 0.0.7 smart-open 5.2.1 snowballstemmer 2.2.0 sortedcontainers 2.4.0 soundfile 0.11.0 spacy 3.4.3 spacy-legacy 3.0.10 spacy-loggers 1.0.3 sparse 0.13.0 Sphinx 1.8.6 sphinxcontrib-serializinghtml 1.1.5 sphinxcontrib-websupport 1.2.4 SQLAlchemy 1.4.44 sqlparse 0.4.3 srsly 2.4.5 statsmodels 0.12.2 sympy 1.7.1 tables 3.7.0 tabulate 0.8.10 tblib 1.7.0 tenacity 8.1.0 tensorboard 2.9.1 tensorboard-data-server 0.6.1 tensorboard-plugin-wit 1.8.1 
tensorflow 2.9.2 tensorflow-datasets 4.6.0 tensorflow-estimator 2.9.0 tensorflow-gcs-config 2.9.1 tensorflow-hub 0.12.0 tensorflow-io-gcs-filesystem 0.28.0 tensorflow-metadata 1.11.0 tensorflow-probability 0.17.0 termcolor 2.1.1 terminado 0.13.3 testpath 0.6.0 text-unidecode 1.3 textblob 0.15.3 thinc 8.1.5 threadpoolctl 3.1.0 tifffile 2022.10.10 toml 0.10.2 tomli 2.0.1 toolz 0.12.0 torch 1.13.0+cu116 torchaudio 0.13.0+cu116 torchsummary 1.5.1 torchtext 0.14.0 torchvision 0.14.0+cu116 tornado 6.0.4 tqdm 4.64.1 traitlets 5.6.0 tweepy 3.10.0 typeguard 2.7.1 typer 0.7.0 typing-extensions 4.4.0 tzlocal 1.5.1 uritemplate 3.0.1 urllib3 1.24.3 vega-datasets 0.9.0 wasabi 0.10.1 wcwidth 0.2.5 webargs 8.2.0 webencodings 0.5.1 Werkzeug 1.0.1 wheel 0.38.4 widgetsnbextension 3.6.1 wordcloud 1.8.2.2 wrapt 1.14.1 xarray 0.20.2 xarray-einstats 0.3.0 xgboost 0.90 xkit 0.0.0 xlrd 1.2.0 xlwt 1.3.0 yarl 1.8.2 yellowbrick 1.5 zict 2.2.0 zipp 3.11.0

kbattocchi commented 1 year ago

Thanks, that helps a ton. So part of the confusion is an unfortunate consequence of the fact that we're returning a defaultdict instance as the outer level of the dictionary, which lets you index into it with any key (even one not originally in the dictionary, in which case an empty entry with that key is created). This is why "Y0" is showing up as an entry - it's created when you try to access shap_values["Y0"]["T0"], but it's created as an empty dictionary so there are no keys in it and trying to get the entry for "T0" fails.
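The behaviour can be reproduced with the standard library alone. The sketch below assumes the outer dictionary behaves like defaultdict(dict), and uses placeholder strings where the real result would hold SHAP explanation objects.

```python
from collections import defaultdict

# Stand-in for the nested result returned by shap_values (assumed structure).
shap_values = defaultdict(dict)
shap_values["Sales"]["treatment_1"] = "<explanation>"

print(list(shap_values))  # only the real outcome name is present

try:
    shap_values["Y0"]["T0"]  # looking up "Y0" silently creates an empty entry
except KeyError as err:
    print("KeyError:", err)  # the error is raised for "T0", not for "Y0"

print(list(shap_values))  # "Y0" has now been added as a key
print(shap_values["Y0"])  # and its value is an empty dict

# A membership test does not create entries, so it is a safer way to probe.
print("Y1" in shap_values, list(shap_values))
```

Printing list(shap_values) before indexing, or testing with in, avoids polluting the dictionary with empty entries.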

Instead, you need to access shap_values["Sales"] at the first level. Check what the keys are there; the second level is probably either "T0" or "treatment0", but I'm not sure offhand.

Thanks for raising this issue - I think that from our side we should move away from using a defaultdict as the outer dictionary since it makes it much harder to understand what's going wrong.

fobembe commented 1 year ago

I am so grateful to you for graciously responding to my issue. I tried shap_values['Sales'] and got 'treatment_01'. I also tried shap_values['treatment0'], shap_values['T0'], and shap_values['Y0'] at the second level, but I got an empty dictionary. Here is the code:

image

kbattocchi commented 1 year ago

In that case, try shap_values['Sales']['treatment_1']

fobembe commented 1 year ago

We are finally there. Thanks a million times. I am so grateful.

image

fobembe commented 1 year ago

The only problem I see in the above plot is that the price and income variables are no longer displayed. These factors were displayed as the most important factors when I used causalml to plot it. Do you have any idea why this is not showing here?

kbattocchi commented 1 year ago

I'd guess that they are bunched together with two other features in the last entry; by default the beeswarm plot will only show the top 10 features but you can adjust this by passing a higher value for the max_display argument when you call beeswarm (see https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html#A-simple-beeswarm-summary-plot).
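A rough sketch of why features can drop out of the plot: it assumes features are ranked by mean absolute SHAP value, and that when there are more than max_display features only the top max_display - 1 get their own row while the rest are merged into the last one. The helper split_displayed and the 13 made-up features are hypothetical and are not part of shap.

```python
def split_displayed(feature_names, shap_rows, max_display=10):
    """Split features into those shown individually and those merged together.

    shap_rows is a list of per-sample lists of SHAP values, one per feature.
    """
    n_rows = len(shap_rows)
    mean_abs = [
        sum(abs(row[j]) for row in shap_rows) / n_rows
        for j in range(len(feature_names))
    ]
    # Rank features from largest to smallest mean absolute SHAP value.
    ranked = [name for _, name in sorted(zip(mean_abs, feature_names), reverse=True)]
    if len(ranked) <= max_display:
        return ranked, []
    return ranked[: max_display - 1], ranked[max_display - 1 :]


# 13 made-up features whose importance grows with their index.
names = [f"x{i}" for i in range(13)]
rows = [[float(i) for i in range(13)], [-float(i) for i in range(13)]]

shown, merged = split_displayed(names, rows)
print("shown individually:", shown)
print("merged into last row:", merged)

# With a larger max_display every feature gets its own row.
print(split_displayed(names, rows, max_display=13)[1])
```

Under these assumptions, price and income ending up in the merged row would mean their SHAP values are among the smallest, which is worth checking once all features are displayed.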

fobembe commented 1 year ago

Thanks a lot. I really appreciate all the efforts.

image

kbattocchi commented 1 year ago

Happy to help.

That chart doesn't look particularly useful in its present form (all SHAP values appear to be 0) - you might try initializing the forest with just CausalForestDML(model_y=RandomForestRegressor(random_state=42), discrete_treatment=True) and then calling tune on it with your data, and then calling fit, to see if better hyperparameters improve the quality of the CATE estimate, and whether that has any impact on the SHAP values.

fobembe commented 1 year ago

I will do just that.

Thanks.

fobembe commented 1 year ago

Here is my final output. You're wonderfully appreciated. Please keep up the good work.

image