Closed fobembe closed 1 year ago
Most probably you have a binary treatment, in which case you need to use the value of the treatment itself as part of the key. You can print the keys of the dictionary shap_values['Y0'] to see what the right second key is instead of 'T0'.
Check cells 6,7,8 here https://github.com/microsoft/EconML/blob/main/notebooks/Interpretability%20with%20SHAP.ipynb
I think you need to pass T0_1 if your treatments were taking values 0/1 and you are using the DRLearner or the metalearners.
Thanks a lot for your reply. I am using a causal forest and tried the following, but I still get an error.
It would be very helpful to get a complete, self-contained repro of your issue (e.g. in the last screenshot above it looks like the last three outputs result from evaluations 30, 37, and 49 in the notebook, but without knowing what's happening in between those calls it's very hard to diagnose the problem).
If you have continuous treatments, then I believe that the dictionary returned by est.shap_values(X_train) should always have a single key in the outer dictionary for each outcome, and indexing into this dictionary should give you an inner dictionary with one key per treatment. The names of those keys will depend on whether the outcomes and treatments already have names (for example, if you passed in a pandas Series during fitting we will take the name from that) or if we just assign default names.
What are the dimensions of your Y, T, and X arrays? Also, can you print the output of pip list, just to make sure that we understand what environment you're running in?
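To make the outer/inner key structure described above concrete, here is a minimal sketch of a helper that lists every (outcome, treatment) key pair in a nested dictionary. The helper name list_shap_keys and the stand-in dictionary are hypothetical; with a fitted estimator you would pass the result of est.shap_values(X_train) instead.

```python
from collections.abc import Mapping


def list_shap_keys(shap_values):
    """Return every (outcome_key, treatment_key) pair found in a nested mapping.

    Iterating with .items() only reads existing entries, so it never
    creates new keys the way indexing into a defaultdict can.
    """
    pairs = []
    for outcome_key, inner in shap_values.items():
        if isinstance(inner, Mapping):
            for treatment_key in inner:
                pairs.append((outcome_key, treatment_key))
    return pairs


# Stand-in for the real return value of est.shap_values(X_train);
# the key names here are illustrative, not guaranteed defaults.
example = {"Y0": {"T0_1": None}}
print(list_shap_keys(example))  # [('Y0', 'T0_1')]
```

Printing the pairs first avoids guessing at key names such as "Y0" or "T0".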
I really appreciate this. My treatment is binary 0/1 and my outcome is continuous (Sales). I reran the whole code to provide a complete picture of what is happening. Thanks again for your assistance.
Output of pip list (package name followed by version):
absl-py 1.3.0 aeppl 0.0.33 aesara 2.7.9 aiohttp 3.8.3 aiosignal 1.3.1 alabaster 0.7.12 albumentations 1.2.1 altair 4.2.0 appdirs 1.4.4 arviz 0.12.1 astor 0.8.1 astropy 4.3.1 astunparse 1.6.3 async-timeout 4.0.2 atari-py 0.2.9 atomicwrites 1.4.1 attrs 22.1.0 audioread 3.0.0 autograd 1.5 Babel 2.11.0 backcall 0.2.0 beautifulsoup4 4.6.3 bleach 5.0.1 blis 0.7.9 bokeh 2.3.3 branca 0.6.0 bs4 0.0.1 CacheControl 0.12.11 cachetools 5.2.0 catalogue 2.0.8 certifi 2022.9.24 cffi 1.15.1 cftime 1.6.2 chardet 3.0.4 charset-normalizer 2.1.1 click 7.1.2 clikit 0.6.2 cloudpickle 1.5.0 cmake 3.22.6 cmdstanpy 1.0.8 colorcet 3.0.1 colorlover 0.3.0 community 1.0.0b1 confection 0.0.3 cons 0.4.5 contextlib2 0.5.5 convertdate 2.4.0 crashtest 0.3.1 crcmod 1.7 cufflinks 0.17.3 cvxopt 1.3.0 cvxpy 1.2.2 cycler 0.11.0 cymem 2.0.7 Cython 0.29.32 daft 0.0.4 dask 2022.2.1 datascience 0.17.5 db-dtypes 1.0.4 debugpy 1.0.0 decorator 4.4.2 defusedxml 0.7.1 descartes 1.1.0 dill 0.3.6 distributed 2022.2.1 dlib 19.24.0 dm-tree 0.1.7 dnspython 2.2.1 docutils 0.17.1 dopamine-rl 1.0.5 earthengine-api 0.1.334 easydict 1.10 econml 0.14.0 ecos 2.0.10 editdistance 0.5.3 en-core-web-sm 3.4.1 entrypoints 0.4 ephem 4.1.3 et-xmlfile 1.1.0 etils 0.9.0 etuples 0.3.8 fa2 0.3.5 fastai 2.7.10 fastcore 1.5.27 fastdownload 0.0.7 fastdtw 0.3.4 fastjsonschema 2.16.2 fastprogress 1.0.3 fastrlock 0.8.1 feather-format 0.4.1 filelock 3.8.0 firebase-admin 5.3.0 fix-yahoo-finance 0.0.22 Flask 1.1.4 flatbuffers 1.12 folium 0.12.1.post1 frozenlist 1.3.3 fsspec 2022.11.0 future 0.16.0 gast 0.4.0 GDAL 2.2.2 gdown 4.4.0 gensim 3.6.0 geographiclib 1.52 geopy 1.17.0 gin-config 0.5.0 glob2 0.7 google 2.0.3 google-api-core 2.8.2 google-api-python-client 1.12.11 google-auth 2.15.0 google-auth-httplib2 0.0.4 google-auth-oauthlib 0.4.6 google-cloud-bigquery 3.3.6 google-cloud-bigquery-storage 2.16.2 google-cloud-core 2.3.2 google-cloud-datastore 2.9.0 google-cloud-firestore 2.7.2 google-cloud-language 2.6.1 google-cloud-storage 2.5.0 
google-cloud-translate 3.8.4 google-colab 1.0.0 google-crc32c 1.5.0 google-pasta 0.2.0 google-resumable-media 2.4.0 googleapis-common-protos 1.57.0 googledrivedownloader 0.4 graphviz 0.10.1 greenlet 2.0.1 grpcio 1.51.1 grpcio-status 1.48.2 gspread 3.4.2 gspread-dataframe 3.0.8 gym 0.25.2 gym-notices 0.0.8 h5py 3.1.0 HeapDict 1.0.1 hijri-converter 2.2.4 holidays 0.17.2 holoviews 1.14.9 html5lib 1.0.1 httpimport 0.5.18 httplib2 0.17.4 httpstan 4.6.1 humanize 0.5.1 hyperopt 0.1.2 idna 2.10 imageio 2.9.0 imagesize 1.4.1 imbalanced-learn 0.8.1 imblearn 0.0 imgaug 0.4.0 importlib-metadata 4.13.0 importlib-resources 5.10.0 imutils 0.5.4 inflect 2.1.0 intel-openmp 2022.2.1 intervaltree 2.1.0 ipykernel 5.3.4 ipython 7.9.0 ipython-genutils 0.2.0 ipython-sql 0.3.9 ipywidgets 7.7.1 itsdangerous 1.1.0 jax 0.3.25 jaxlib 0.3.25+cuda11.cudnn805 jieba 0.42.1 Jinja2 2.11.3 joblib 1.2.0 jpeg4py 0.1.4 jsonschema 4.3.3 jupyter-client 6.1.12 jupyter-console 6.1.0 jupyter-core 5.1.0 jupyterlab-widgets 3.0.3 kaggle 1.5.12 kapre 0.3.7 keras 2.9.0 Keras-Preprocessing 1.1.2 keras-vis 0.4.1 kiwisolver 1.4.4 korean-lunar-calendar 0.3.1 langcodes 3.3.0 libclang 14.0.6 librosa 0.8.1 lightgbm 2.2.3 llvmlite 0.39.1 lmdb 0.99 locket 1.0.0 logical-unification 0.4.5 LunarCalendar 0.0.9 lxml 4.9.1 Markdown 3.4.1 MarkupSafe 2.0.1 marshmallow 3.19.0 matplotlib 3.2.2 matplotlib-venn 0.11.7 miniKanren 1.0.3 missingno 0.5.1 mistune 0.8.4 mizani 0.7.3 mkl 2019.0 mlxtend 0.14.0 more-itertools 9.0.0 moviepy 0.2.3.5 mpmath 1.2.1 msgpack 1.0.4 multidict 6.0.3 multipledispatch 0.6.0 multitasking 0.0.11 murmurhash 1.0.9 music21 5.5.0 natsort 5.5.0 nbconvert 5.6.1 nbformat 5.7.0 netCDF4 1.6.2 networkx 2.8.8 nibabel 3.0.2 nltk 3.7 notebook 5.7.16 numba 0.56.4 numexpr 2.8.4 numpy 1.21.6 oauth2client 4.1.3 oauthlib 3.2.2 okgrade 0.4.3 opencv-contrib-python 4.6.0.66 opencv-python 4.6.0.66 opencv-python-headless 4.6.0.66 openpyxl 3.0.10 opt-einsum 3.3.0 osqp 0.6.2.post0 packaging 21.3 palettable 3.3.0 pandas 1.3.5 
pandas-datareader 0.9.0 pandas-gbq 0.17.9 pandas-profiling 1.4.1 pandocfilters 1.5.0 panel 0.12.1 param 1.12.2 parso 0.8.3 partd 1.3.0 pastel 0.2.1 pathlib 1.0.1 pathy 0.10.0 patsy 0.5.3 pep517 0.13.0 pexpect 4.8.0 pickleshare 0.7.5 Pillow 7.1.2 pip 21.1.3 pip-tools 6.2.0 platformdirs 2.5.4 plotly 5.5.0 plotnine 0.8.0 pluggy 0.7.1 pooch 1.6.0 portpicker 1.3.9 prefetch-generator 1.0.3 preshed 3.0.8 prettytable 3.5.0 progressbar2 3.38.0 prometheus-client 0.15.0 promise 2.3 prompt-toolkit 2.0.10 prophet 1.1.1 proto-plus 1.22.1 protobuf 3.19.6 psutil 5.4.8 psycopg2 2.9.5 ptyprocess 0.7.0 py 1.11.0 pyarrow 9.0.0 pyasn1 0.4.8 pyasn1-modules 0.2.8 pycocotools 2.0.6 pycparser 2.21 pyct 0.4.8 pydantic 1.10.2 pydata-google-auth 1.4.0 pydot 1.3.0 pydot-ng 2.0.0 pydotplus 2.0.2 PyDrive 1.3.1 pyemd 0.5.1 pyerfa 2.0.0.1 Pygments 2.6.1 pygobject 3.26.1 pylev 1.4.0 pymc 4.1.4 PyMeeus 0.5.11 pymongo 4.3.3 pymystem3 0.2.0 PyOpenGL 3.1.6 pyparsing 3.0.9 pyrsistent 0.19.2 pysimdjson 3.2.0 pysndfile 1.3.8 PySocks 1.7.1 pystan 3.3.0 pytest 3.6.4 python-apt 0.0.0 python-dateutil 2.8.2 python-louvain 0.16 python-slugify 7.0.0 python-utils 3.4.5 pytz 2022.6 pyviz-comms 2.2.1 PyWavelets 1.4.1 PyYAML 6.0 pyzmq 23.2.1 qdldl 0.1.5.post2 qudida 0.0.4 regex 2022.6.2 requests 2.23.0 requests-oauthlib 1.3.1 resampy 0.4.2 rpy2 3.5.5 rsa 4.9 scikit-image 0.18.3 scikit-learn 1.0.2 scipy 1.7.3 screen-resolution-extra 0.0.0 scs 3.2.2 seaborn 0.11.2 Send2Trash 1.8.0 setuptools 57.4.0 setuptools-git 1.2 shap 0.40.0 Shapely 1.8.5.post1 six 1.15.0 sklearn-pandas 1.8.0 slicer 0.0.7 smart-open 5.2.1 snowballstemmer 2.2.0 sortedcontainers 2.4.0 soundfile 0.11.0 spacy 3.4.3 spacy-legacy 3.0.10 spacy-loggers 1.0.3 sparse 0.13.0 Sphinx 1.8.6 sphinxcontrib-serializinghtml 1.1.5 sphinxcontrib-websupport 1.2.4 SQLAlchemy 1.4.44 sqlparse 0.4.3 srsly 2.4.5 statsmodels 0.12.2 sympy 1.7.1 tables 3.7.0 tabulate 0.8.10 tblib 1.7.0 tenacity 8.1.0 tensorboard 2.9.1 tensorboard-data-server 0.6.1 tensorboard-plugin-wit 1.8.1 
tensorflow 2.9.2 tensorflow-datasets 4.6.0 tensorflow-estimator 2.9.0 tensorflow-gcs-config 2.9.1 tensorflow-hub 0.12.0 tensorflow-io-gcs-filesystem 0.28.0 tensorflow-metadata 1.11.0 tensorflow-probability 0.17.0 termcolor 2.1.1 terminado 0.13.3 testpath 0.6.0 text-unidecode 1.3 textblob 0.15.3 thinc 8.1.5 threadpoolctl 3.1.0 tifffile 2022.10.10 toml 0.10.2 tomli 2.0.1 toolz 0.12.0 torch 1.13.0+cu116 torchaudio 0.13.0+cu116 torchsummary 1.5.1 torchtext 0.14.0 torchvision 0.14.0+cu116 tornado 6.0.4 tqdm 4.64.1 traitlets 5.6.0 tweepy 3.10.0 typeguard 2.7.1 typer 0.7.0 typing-extensions 4.4.0 tzlocal 1.5.1 uritemplate 3.0.1 urllib3 1.24.3 vega-datasets 0.9.0 wasabi 0.10.1 wcwidth 0.2.5 webargs 8.2.0 webencodings 0.5.1 Werkzeug 1.0.1 wheel 0.38.4 widgetsnbextension 3.6.1 wordcloud 1.8.2.2 wrapt 1.14.1 xarray 0.20.2 xarray-einstats 0.3.0 xgboost 0.90 xkit 0.0.0 xlrd 1.2.0 xlwt 1.3.0 yarl 1.8.2 yellowbrick 1.5 zict 2.2.0 zipp 3.11.0
Thanks, that helps a ton. So part of the confusion is an unfortunate consequence of the fact that we're returning a defaultdict instance as the outer level of the dictionary, which lets you index into it with any key (even one not originally in the dictionary, in which case an empty entry with that key is created). This is why "Y0" is showing up as an entry: it's created when you try to access shap_values["Y0"]["T0"], but it's created as an empty dictionary, so there are no keys in it and trying to get the entry for "T0" fails.
Instead, you need to access shap_values["Sales"] at the first level. Check what the keys are there; the second level is probably either "T0" or "treatment0", but I'm not sure offhand.
Thanks for raising this issue. I think that from our side we should move away from using a defaultdict as the outer dictionary, since it makes it much harder to understand what's going wrong.
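The defaultdict behaviour described in this comment can be reproduced in plain Python. The sketch below assumes the outer container behaves like defaultdict(dict); the variable names and the stored string are placeholders rather than real SHAP output.

```python
from collections import defaultdict

# Stand-in for the structure returned by shap_values: outer defaultdict,
# inner plain dict. The stored value is a placeholder string.
shap_like = defaultdict(dict)
shap_like["Sales"]["treatment_1"] = "explanation placeholder"

keys_before = list(shap_like.keys())
print(keys_before)  # ['Sales']

# Indexing with a key that was never stored silently creates an empty dict.
inner = shap_like["Y0"]
keys_after = list(shap_like.keys())
print(keys_after)  # ['Sales', 'Y0']

# The inner dict is a plain dict, so a missing treatment key raises KeyError.
try:
    inner["T0"]
except KeyError as err:
    missing_key = err.args[0]
    print("KeyError for inner key:", missing_key)
```

This is why a stray "Y0" entry appears after a failed lookup: the lookup itself created it.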
I am so grateful to you for graciously responding to my issue. I tried shap_values['Sales'] and got 'treatment_01'. I then tried shap_values['treatment0'], shap_values['T0'], and shap_values['Y0'] at the second level, but I got an empty dictionary. Here is the code:
In that case, try shap_values['Sales']['treatment_1']
We are finally there. Thanks a million times. I am so grateful.
The only problem I see in the above plot is that the price and income variables are no longer displayed. These factors were displayed as the most important factors when I used causalml to plot it. Do you have any idea why this is not showing here?
I'd guess that they are bunched together with two other features in the last entry; by default the beeswarm plot will only show the top 10 features, but you can adjust this by passing a higher value for the max_display argument when you call beeswarm (see https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html#A-simple-beeswarm-summary-plot).
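As a rough sketch of the max_display suggestion, the helpers below pick a display count that covers every feature and pass it to shap.plots.beeswarm. The names pick_max_display and beeswarm_all_features are hypothetical, and the code assumes the explanation object exposes a two-dimensional values array (rows are samples, columns are features).

```python
def pick_max_display(n_features, requested=None):
    """Return the requested display count, or n_features to show everything."""
    if requested is None:
        return n_features
    return requested


def beeswarm_all_features(explanation, requested=None):
    """Draw a beeswarm plot without collapsing features into a summed last row.

    `explanation` is assumed to be a shap.Explanation with 2-D `.values`,
    e.g. shap_values['Sales']['treatment_1'] from this thread.
    """
    import shap  # imported lazily; requires the shap package to be installed

    n_features = explanation.values.shape[1]
    shap.plots.beeswarm(
        explanation, max_display=pick_max_display(n_features, requested)
    )
```

Calling beeswarm_all_features(shap_values['Sales']['treatment_1']) should then list each feature on its own row, at the cost of a taller plot.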
Thanks a lot. I really appreciate all the efforts.
Happy to help.
That chart doesn't look particularly useful in its present form (all SHAP values appear to be 0). You might try initializing the forest with just CausalForestDML(model_y=RandomForestRegressor(random_state=42), discrete_treatment=True), then calling tune on it with your data, and then calling fit, to see if better hyperparameters improve the quality of the CATE estimate, and whether that has any impact on the SHAP values.
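The tune-then-fit sequence suggested above might look roughly like the following. This is a sketch under assumptions: the function name tune_then_fit and the Y, T, X arguments are hypothetical, and it assumes an econml version where CausalForestDML provides tune(Y, T, X=...) and fit(Y, T, X=...).

```python
def tune_then_fit(Y, T, X):
    """Build a CausalForestDML, tune its forest hyperparameters, then fit it.

    Y: outcome array, T: binary treatment array, X: feature matrix.
    """
    # Imported lazily; requires econml and scikit-learn to be installed.
    from econml.dml import CausalForestDML
    from sklearn.ensemble import RandomForestRegressor

    est = CausalForestDML(
        model_y=RandomForestRegressor(random_state=42),
        discrete_treatment=True,  # the treatment in this thread is binary 0/1
    )
    est.tune(Y, T, X=X)  # searches forest hyperparameters on the data
    est.fit(Y, T, X=X)   # tune does not fit the final model, so fit afterwards
    return est
```

After est = tune_then_fit(Y, T, X), calling est.shap_values(X) again would show whether the retuned forest produces non-zero SHAP values.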
I will do just that.
Thanks.
Here is my final output. You're wonderfully appreciated. Please keep up the good work.
I am trying to implement the SHAP plots to obtain the relevant factors for my predictive analytics. I have done this with causalml without any problem, but I have been struggling to get this done with econml. Does anyone have a clue about what I am doing wrong in the following code?
I'll appreciate any help. Thanks.