py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

Fix RegressionEstimator categorical one-hot encoding consistency bug. #1109

Closed drawlinson closed 7 months ago

drawlinson commented 7 months ago

Fix RegressionEstimator categorical one-hot encoding consistency bug by changing from pandas get_dummies() to sklearn OneHotEncoder.

Encoder objects are created during RegressionEstimator.fit() and persist until the next fit(), so they can be re-applied to encode new data, either via additional calls to CausalModel.estimate_effect(..., fit_estimator=False, ...) or via the do() operator.

In the earlier implementation, common causes, effect modifiers and potentially treatment values could be encoded inconsistently between fit() and later inference, depending on the order in which particular values were encountered in the new data.
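One way this kind of mismatch can arise is sketched below. It is a minimal illustration using only pandas, not DoWhy code; the column name `region` and its values are made up for the example. Because get_dummies() derives its columns from whatever values are present in the frame it is given, the new data ends up with a different column layout than the data used at fit time.

```python
import pandas as pd

# Hypothetical data: "region" stands in for a categorical common cause.
train = pd.DataFrame({"region": ["north", "south", "west"]})
new = pd.DataFrame({"region": ["south", "west"]})  # "north" never appears

# Each call infers categories independently from the frame it receives.
train_encoded = pd.get_dummies(train, drop_first=True)
new_encoded = pd.get_dummies(new, drop_first=True)

print(list(train_encoded.columns))  # ['region_south', 'region_west']
print(list(new_encoded.columns))    # ['region_west']
```

In the second frame "south" becomes the dropped baseline, so a regression coefficient learned for region_south at fit time has no matching column at inference time.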

To fix this, a utility function is added which patches sklearn OneHotEncoder to behave like pandas get_dummies(), together with a convenience member function of RegressionEstimator called _encode() that makes each use a one-line change.

It is also now possible to change drop_first from True (the current and original default) to False, so that all regression coefficients can be inspected when interpreting model behaviour.