BiomedSciAI / causallib

A Python package for modular causal inference analysis and model evaluations
Apache License 2.0
728 stars 97 forks

Issues with Categorical Data #48

Closed jgdpsingh closed 1 year ago

jgdpsingh commented 1 year ago

I'm working with survey data, trying to figure out the cause-and-effect relationship between a person's responses to the survey questions and their final preference towards a product. The data is entirely categorical. When using Causal Inference 360's evaluation plots, I got the following error:

/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric precision could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric recall could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric f1 could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric roc_auc could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: multi_class must be in ('ovo', 'ovr')
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric avg_precision could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: multiclass format is not supported
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric hinge could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: The shape of pred_decision cannot be 1d array with a multiclass target. pred_decision shape must be (n_samples, n_classes), that is (1977, 3). Got: (1977,)
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric brier could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: Only binary classification is supported. The type of the target is multiclass.
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric roc_curve could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: multiclass format is not supported
  warnings.warn(str(v))
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:106: UserWarning: metric pr_curve could not be evaluated
  warnings.warn(f"metric {metric_name} could not be evaluated")
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/metrics.py:107: UserWarning: multiclass format is not supported
  warnings.warn(str(v))

KeyError                                  Traceback (most recent call last)
<ipython-input> in <module>
      3
      4 eval_results = evaluate(ipw, X, a, y)
----> 5 eval_results.plot_all()
      6 eval_results.plot_covariate_balance(kind="love");

8 frames
/usr/local/lib/python3.8/dist-packages/causallib/evaluation/plots/mixins.py in plot_all(self, phase)
    343         """
    344         phases_to_plot = self.predictions.keys() if phase is None else [phase]
--> 345         multipanel_plot = {
    346             plotted_phase: self._make_multipanel_evaluation_plot(
    347                 plot_names=self.all_plot_names, phase=plotted_phase

/usr/local/lib/python3.8/dist-packages/causallib/evaluation/plots/mixins.py in <dictcomp>(.0)
    344         phases_to_plot = self.predictions.keys() if phase is None else [phase]
    345         multipanel_plot = {
--> 346             plotted_phase: self._make_multipanel_evaluation_plot(
    347                 plot_names=self.all_plot_names, phase=plotted_phase
    348             )

/usr/local/lib/python3.8/dist-packages/causallib/evaluation/plots/mixins.py in _make_multipanel_evaluation_plot(self, plot_names, phase)
    353     def _make_multipanel_evaluation_plot(self, plot_names, phase):
    354         phase_fig, phase_axes = plots.get_subplots(len(plot_names))
--> 355         named_axes = {
    356             name: self._make_single_panel_evaluation_plot(name, phase, ax)
    357             for name, ax in zip(plot_names, phase_axes.ravel())

/usr/local/lib/python3.8/dist-packages/causallib/evaluation/plots/mixins.py in <dictcomp>(.0)
    354         phase_fig, phase_axes = plots.get_subplots(len(plot_names))
    355         named_axes = {
--> 356             name: self._make_single_panel_evaluation_plot(name, phase, ax)
    357             for name, ax in zip(plot_names, phase_axes.ravel())
    358         }

/usr/local/lib/python3.8/dist-packages/causallib/evaluation/plots/mixins.py in _make_single_panel_evaluation_plot(self, plot_name, phase, ax, **kwargs)
    379         plot_func = plots.lookup_name(plot_name)
    380         plot_data = self.get_data_for_plot(plot_name, phase=phase)
--> 381         return plot_func(plot_data, ax=ax, **kwargs)

/usr/local/lib/python3.8/dist-packages/causallib/evaluation/plots/plots.py in plot_mean_features_imbalance_love_folds(table1_folds, cv, aggregate_folds, thresh, plot_semi_grid, ax)
    813         aggregated_table1 = aggregated_table1.groupby(aggregated_table1.index)
    814
--> 815         order = aggregated_table1.mean().sort_values(by="unweighted", ascending=True).index
    816
    817         if aggregate_folds:

/usr/local/lib/python3.8/dist-packages/pandas/util/_decorators.py in wrapper(*args, **kwargs)
    309                 stacklevel=stacklevel,
    310             )
--> 311             return func(*args, **kwargs)
    312
    313         return wrapper

/usr/local/lib/python3.8/dist-packages/pandas/core/frame.py in sort_values(self, by, axis, ascending, inplace, kind, na_position, ignore_index, key)
   6257
   6258             by = by[0]
-> 6259             k = self._get_label_or_level_values(by, axis=axis)
   6260
   6261             # need to rewrap column in Series to apply key function

/usr/local/lib/python3.8/dist-packages/pandas/core/generic.py in _get_label_or_level_values(self, key, axis)
   1777             values = self.axes[axis].get_level_values(key)._values
   1778         else:
-> 1779             raise KeyError(key)
   1780
   1781         # Check for duplicates

KeyError: 'unweighted'

Can you help me out please?

ehudkr commented 1 year ago

Hi, I'm sorry you encountered a problem, and thank you for bringing that up.

It's hard to say without the input data, but on the face of it, it seems you have a multiple-treatment setting, which is not well-defined for most of the evaluations because they require a binary treatment. For example, the Love plot takes a covariate x and computes the difference of means between two groups: x|a=1 and x|a=0. However, if you have more than two groups, that difference is no longer well-defined. (The same argument goes for the other evaluations; for example, ROC curves are not well-defined for a non-binary target.)
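To make the Love-plot point concrete, here is a minimal sketch of the standardized mean difference that underlies it (this is not causallib's actual implementation; the function name and the pooled-standard-deviation choice are mine):

```python
import numpy as np

def standardized_mean_difference(x, a):
    """Absolute standardized difference of covariate x between the
    treated (a == 1) and control (a == 0) groups.
    Only defined for exactly two groups."""
    x1, x0 = x[a == 1], x[a == 0]
    pooled_std = np.sqrt((x1.var(ddof=1) + x0.var(ddof=1)) / 2)
    return abs(x1.mean() - x0.mean()) / pooled_std

rng = np.random.default_rng(0)
a = rng.integers(0, 2, size=200)           # binary treatment: well-defined
x = rng.normal(loc=a, scale=1.0)           # covariate shifted by treatment
print(standardized_mean_difference(x, a))  # one scalar per covariate

# With 3 treatment arms there is no single "difference of means",
# only pairwise comparisons (0 vs 1, 0 vs 2, 1 vs 2).
```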

Let me know if that isn't the case. If so, and you could share a small synthetic data sample (or snippet) that reproduces the problem, it would make it easier to investigate.
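For reference, the warnings themselves come from scikit-learn: its classification metrics default to `average='binary'`, which fails on a 3-class target. A small standalone illustration of the multiclass-safe settings the warning suggests (the toy labels here are made up):

```python
import numpy as np
from sklearn.metrics import precision_score

y_true = np.array([0, 1, 2, 1, 0, 2])
y_pred = np.array([0, 1, 1, 1, 0, 2])

# average='binary' (the default) raises "Target is multiclass but
# average='binary'" on data like this -- the message in your log.
# Multiclass-safe settings instead:
micro = precision_score(y_true, y_pred, average="micro")  # global TP / (TP + FP)
macro = precision_score(y_true, y_pred, average="macro")  # unweighted mean over classes
print(micro, macro)
```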

jgdpsingh commented 1 year ago

Thank you @ehudkr, I'll try with binary targets and see how it goes. Will post the results here.

jgdpsingh commented 1 year ago

It worked: I changed the target parameter to binary instead of 3 classes. Thanks a lot @ehudkr!

But please consider supporting multi-class treatment problems as well.
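For reference, this is the kind of change I made: collapsing the multi-class treatment to two arms. A sketch (the series values and the arm-vs-rest choice are illustrative, and the commented causallib calls mirror the snippet in my traceback, assuming `ipw`, `X`, and `y` are already defined):

```python
import pandas as pd

# Illustrative 3-class treatment; collapse to binary by comparing
# one arm of interest against the rest (which arm is up to you).
a_multiclass = pd.Series([0, 1, 2, 1, 0, 2, 1])
a_binary = (a_multiclass == 1).astype(int)

print(a_binary.tolist())  # [0, 1, 0, 1, 0, 0, 1]

# The evaluation from the traceback then runs unchanged, e.g.:
# ipw.fit(X, a_binary)
# eval_results = evaluate(ipw, X, a_binary, y)
# eval_results.plot_all()
```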