Open mimicarina opened 1 year ago
I have a probably related issue with categorical columns that contain integer numbers. Calling `Dice.generate_counterfactuals` raises:
ValueError: Found unknown categories ['9', '2', '13', '7', '5', '12', '11', '15', '18', '3', '1', '14', '8', '10', '17', '4', '16'] in column 2 during transform
I realised that `Data.permitted_range` already has the integers of categorical columns converted to strings, which is probably the root cause of the problem. Since my dataframe has only `number` and `category` dtype columns, I can work around it with:
data = dice_ml.Data(dataframe=df_train, continuous_features=df_train.select_dtypes("number").columns, outcome_name="y")
for col in df_train.select_dtypes("category").columns:
    data.permitted_range[col] = df_train[col].cat.categories
Edit: This only works for `Dice(method="random")`, not for `"genetic"` or `"kdtree"`.
Edit2: The actual culprit may be `PublicData._set_feature_dtypes`, where each column in `categorical_feature_names` is converted to `str` before being converted to `category`. However, when I tweak the source code and omit the string conversion, I get another error from the genetic algorithm's `LabelEncoder`, which encodes to int64, which in turn cannot be handled by a NumPy-internal `np.isnan` check.
https://github.com/interpretml/DiCE/blob/e9e7147fce95b09de2f38d21cf59e0031dee28ae/dice_ml/data_interfaces/private_data_interface.py#L336
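The effect of the string conversion described in Edit2 can be reproduced with pandas alone. The following is a minimal sketch, not DiCE's actual code: the toy column is an assumption made for illustration, and it only shows how converting to `str` before `category` changes the category levels.

```python
import pandas as pd

# Hypothetical toy column: a categorical feature whose levels are integers.
s = pd.Series([1, 2, 3, 2, 1])

# Converting directly to category keeps the integer levels.
direct = s.astype("category")

# Converting to str first (the step described in Edit2) yields string levels.
via_str = s.astype(str).astype("category")

print(direct.cat.categories.tolist())   # [1, 2, 3]
print(via_str.cat.categories.tolist())  # ['1', '2', '3']

# The two sets of levels share no element, so anything fitted on one
# representation will treat values from the other as unknown categories.
overlap = set(direct.cat.categories) & set(via_str.cat.categories)
print(overlap)  # set()
```

If one component sees the integer levels and another sees the string levels, a mismatch like the `Found unknown categories` error above would be the expected symptom.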
When categorical columns contain numerical levels (e.g. yes - 1, no - 0), `visualize_as_dataframe(show_only_changes=True)` (and also `visualize_as_list()`) does not work, as it encodes the string values to numerical.

Example dataset: https://archive.ics.uci.edu/ml/machine-learning-databases/00573/SouthGermanCredit.zip.
During data prep, categorical values are encoded as the 'category' data type (see query instance below). The counterfactual uses the numeric representation, so a value is shown as 'changed' even though it is the same category (e.g. '2' vs 2).
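The '2' vs 2 mismatch can be illustrated without DiCE. This is a rough sketch under stated assumptions: `changed_features` is a hypothetical helper, not a DiCE function, and the feature names and values are made up for the example.

```python
def changed_features(query, cf, as_str=False):
    """Return the names of features whose values differ between two rows.

    Hypothetical helper for illustration; `query` and `cf` are plain dicts.
    With as_str=True both values are compared by their string form.
    """
    changed = []
    for name, q_val in query.items():
        c_val = cf[name]
        if as_str:
            q_val, c_val = str(q_val), str(c_val)
        if q_val != c_val:
            changed.append(name)
    return changed


# Made-up rows: the query holds the category as a string,
# the counterfactual holds the same category as an integer.
query = {"credit_history": "2", "duration": 24}
cf = {"credit_history": 2, "duration": 12}

print(changed_features(query, cf))               # ['credit_history', 'duration']
print(changed_features(query, cf, as_str=True))  # ['duration']
```

Comparing string forms is only a reasonable shortcut for categorical levels; for continuous features it could mislead (for example `str(2.0)` is not equal to `str(2)`), so a real fix would more likely align the dtypes up front.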
This happens because `train_data[cat_feature].cat.categories.tolist()` returns integers rather than strings; for the sample dataset above, the categories and levels are: