py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

Inconsistent encoding with pandas get_dummies causes prediction and effect estimation errors #1111

Closed: drawlinson closed this issue 3 months ago

drawlinson commented 7 months ago

Describe the bug

DoWhy mostly aims to model effects within an existing, finite set of data, but generalization to unseen data is an important aspect of model validation that many users (including myself) might want to do. DoWhy has two ways to apply an existing model to new data:

  1. CausalModel.estimate_effect(fit_estimator=False)

This calls through to CausalEstimator.estimate_effect, which seems to assume the model has already been fitted or trained to a dataset and does not retrain it.

  2. CausalModel.do(X, ..., fit_estimator=<any value>)

Although this function takes a fit_estimator argument, it always assumes the model is fitted to the entire dataset before evaluating the effect on the provided new data with the treatment set to X.

Why does new data matter?

i. DoWhy uses Pandas' get_dummies to automatically encode any dataframe column with non-numerical datatypes:

https://pandas.pydata.org/docs/reference/api/pandas.get_dummies.html

ii. get_dummies output depends on the order in which specific values are encountered in the data.

iii. The order in which specific values are encountered in the data is not guaranteed to be the same in any dataset other than the original.

iv. get_dummies will therefore return inconsistent encodings when the rows of the data are shuffled or simply different (an unseen or generalization dataset). This doesn't necessarily cause a crash, but it can if the number of unique values in a categorical column differs between training and validation data. More commonly, the encodings are simply different, leading to incorrect outputs from the model.
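A minimal pandas sketch of point iv, assuming drop_first=True as in the get_dummies call quoted later in this issue. It changes which values appear in each frame (rather than only their order), which is already enough to make independently encoded columns disagree; the column C and the values X, Y, Z are illustrative.

```python
import pandas as pd

# Training data contains X and Y; the "new" data contains Y and Z.
train = pd.DataFrame({"C": ["X", "Y", "X", "X", "Y"]})
new = pd.DataFrame({"C": ["Y", "Z", "Z", "Y"]})

# Each frame is encoded independently, as when get_dummies is called
# separately on training data and on unseen data.
train_enc = pd.get_dummies(train, drop_first=True)
new_enc = pd.get_dummies(new, drop_first=True)

print(list(train_enc.columns))  # ['C_Y']
print(list(new_enc.columns))    # ['C_Z']
# Both have one column, so a model fitted on train_enc accepts new_enc
# without error, but a 1 in that column now means "Z" instead of "Y".

# With three distinct values the column count changes, which is where a
# shape mismatch (and a crash) can occur instead.
wider = pd.get_dummies(pd.DataFrame({"C": ["X", "Y", "Z"]}), drop_first=True)
print(list(wider.columns))      # ['C_Y', 'C_Z']
```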

Steps to reproduce the behavior

For example, if a column C in the data contains the strings X and Y:

C = [X, Y, X, X, Y, ...]

... then initially these values might be encoded as:

X: 0,1
Y: 1,0

If they are presented in a different order:

C = [Y, Y, X, X, Y, ...]

... the same strings would be encoded as:

X: 1,0
Y: 0,1

I notice this issue has come up before: in regression_estimator.py, lines 203-221 of the interventional_outcomes function prefix the entire original treatment column to the new treatment column so that it is encoded consistently. The prefixed original data is then removed before being passed to _build_features(). However, this only fixes the treatment column; the other columns (such as common cause and effect modifier variables) have no mechanism to ensure consistent encoding.
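A simplified pandas sketch of the idea behind that workaround, not DoWhy's actual code: the column names T (treatment) and W (common cause) and the data are made up for illustration, and drop_first=True is assumed. The prefixed treatment column keeps the training columns, while the covariate, encoded from the new data alone, does not.

```python
import pandas as pd

train = pd.DataFrame({"T": ["a", "b", "c"], "W": ["u", "v", "w"]})
new = pd.DataFrame({"T": ["c", "c"], "W": ["w", "v"]})
n = len(train)

# Treatment: prefix the original column so every training category is present,
# encode the combined column, then keep only the rows that belong to the new data.
combined_t = pd.concat([train["T"], new["T"]], ignore_index=True)
t_enc = pd.get_dummies(combined_t, prefix="T", drop_first=True).iloc[n:]

# Covariate: nothing aligns it, so it is encoded from the new data alone.
w_train = pd.get_dummies(train["W"], prefix="W", drop_first=True)
w_new = pd.get_dummies(new["W"], prefix="W", drop_first=True)

print(list(t_enc.columns))    # ['T_b', 'T_c'], matching the training encoding
print(list(w_train.columns))  # ['W_v', 'W_w']
print(list(w_new.columns))    # ['W_w'], one column short of training
```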

Expected behavior

Categorical data types should always be encoded consistently to achieve valid inference.

Version information:

Additional context

Consistent encoding of shuffled or different data is not possible using get_dummies(), because it does not produce an encoder that can be re-applied to additional data.

However, scikit-learn does have a one-hot encoder whose behaviour is very similar to get_dummies() and which can be re-applied to new data:

https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html

DoWhy already depends on scikit-learn, so no additional packages are required.

Since encoding depends on a specific encoder object, the lifecycle of this object must be considered. I suggest that encoders are a property of CausalEstimator objects which are re-created during CausalEstimator.fit() and persist until the next fit(). This means that estimation of effects after fit() would reuse existing encoder[s], ensuring encoding is consistent.

I've made a PR https://github.com/py-why/dowhy/pull/1112 which fixes all RegressionEstimator subclasses as described above.

Note: as a side effect of this fix, it is easy to create a CausalEstimator.predict() method that allows model inference on arbitrary data; the do() method then reduces to assigning the treatment value and calling predict():

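def _do(self, data_df, treatment_val):
    interventional_outcomes = self.interventional_outcomes(data_df, treatment_val)
    # Average the per-unit outcomes under the intervention
    return interventional_outcomes.mean()

def interventional_outcomes(self, data_df, treatment_val):
    # Copy so the caller's dataframe is not modified in place
    data_df = data_df.copy()
    # Set the treatment to the intervention value for every unit
    data_df[self._target_estimand.treatment_variable] = treatment_val
    return self.predict(data_df)

def predict(self, data_df):
    # Build features with the encoders created during fit(), then predict
    features = self._build_features(data_df)
    return self.predict_fn(data_df, self.model, features)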
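A self-contained sketch of the same call chain outside DoWhy, assuming a scikit-learn Pipeline (ColumnTransformer + OneHotEncoder + LinearRegression) in place of DoWhy's estimator; the columns W (common cause), T (treatment) and y (outcome) and the module-level functions are illustrative, not DoWhy's API. Because the pipeline keeps the fitted encoder, predict() works on arbitrary new rows, and do() is just "set the treatment, then predict".

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder

# Illustrative data: y = 2*T + 3*[W == "Y"].
train = pd.DataFrame({
    "W": ["X", "Y", "X", "Y", "X", "Y"],
    "T": [0.0, 1.0, 1.0, 0.0, 1.0, 0.0],
    "y": [0.0, 5.0, 2.0, 3.0, 2.0, 3.0],
})

# The pipeline owns the fitted encoder, so later data is encoded the same way.
model = make_pipeline(
    ColumnTransformer(
        [("onehot", OneHotEncoder(drop="first"), ["W"])],
        remainder="passthrough",
    ),
    LinearRegression(),
)
model.fit(train[["W", "T"]], train["y"])


def predict(data_df):
    return model.predict(data_df[["W", "T"]])


def interventional_outcomes(data_df, treatment_val):
    data_df = data_df.copy()      # leave the caller's data untouched
    data_df["T"] = treatment_val  # do(T = treatment_val) for every unit
    return predict(data_df)


def do(data_df, treatment_val):
    # Average outcome under the intervention.
    return interventional_outcomes(data_df, treatment_val).mean()


new = pd.DataFrame({"W": ["Y", "X"], "T": [0.0, 0.0]})
effect = do(new, 1.0) - do(new, 0.0)
print(round(effect, 6))  # 2.0
```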

In addition, the get_dummies calls can be replaced with a new util function that approximates the original get_dummies interface, e.g.

Old:

self._observed_common_causes = pd.get_dummies(self._observed_common_causes, drop_first=True)

New:

self._observed_common_causes = self._encode(self._observed_common_causes, "observed_common_causes")

The same fix could be applied to all uses of get_dummies in DoWhy (there are 13 remaining in other CausalEstimators), but since this PR is already complex and all tests pass, I thought it was worth getting feedback on the fix first...

drawlinson commented 7 months ago

PR here https://github.com/py-why/dowhy/pull/1112

drawlinson commented 6 months ago

I think this is resolved by the merge of PR 1112. However, it's also an option to remove use of get_dummies entirely, if this is considered desirable, using the same util function to replace each occurrence. I can make another PR to complete that process. Keen for thoughts on that... pushing encoders down to the base CausalEstimator would simplify things, and behaviour would be more consistent between estimators.

amit-sharma commented 6 months ago

That's a great idea, @drawlinson. Yeah, it is better to remove use of get_dummies. Looking forward to your PR.

drawlinson commented 6 months ago

I'll get onto that ASAP!

drawlinson commented 5 months ago

@amit-sharma PR now available to complete the job... https://github.com/py-why/dowhy/pull/1135

drawlinson commented 3 months ago

Resolved with PR #1135