interpretml / DiCE

Generate Diverse Counterfactual Explanations for any machine learning model.
https://interpretml.github.io/DiCE/
MIT License

Deal with missing values #224

Open londumas opened 3 years ago

londumas commented 3 years ago

There does not seem to be any support for missing values currently. For example, when dealing with the 'Age' feature in the Titanic dataset:

```python
import pandas as pd
import sklearn
import catboost
import dice_ml

def full_data_set(dataset='train'):
    df = pd.read_csv(
        "../{}.csv".format(dataset))
    if dataset=='test':
        df_truth = pd.read_csv(
            "../gender_submission.csv")
        df = pd.concat([df, df_truth],axis=1)

    ###
    for c in ['Cabin','Embarked']:
        df[c] = df[c].apply(str)

    drops = ['PassengerId','Name','Ticket','Cabin','Embarked']
    df = df.drop(drops,axis=1)

    return df

dataset = full_data_set('train')

###
X_train = dataset.copy()
X_train = X_train.drop(['Survived'],axis=1)
y_train = dataset['Survived']

cat_features = ['Sex']
cat_features_index = [i for i, col in enumerate(X_train.columns) if col in cat_features]

model = catboost.CatBoostClassifier(iterations=200,
                           depth=5,
                           learning_rate=1,
                           loss_function='Logloss',
                           verbose=20,
                           task_type="CPU",
                           devices='0',
                           cat_features=cat_features_index)
# train the model
_ = model.fit(X_train, y_train)

###
dataset = full_data_set('train')

continuous_features = [ el for el in dataset.columns if (el!='Survived') and (el not in cat_features)]
print("continuous_features = ",continuous_features)

d = dice_ml.Data(dataframe=dataset,
                 continuous_features=continuous_features,
                 outcome_name='Survived')

# provide the trained ML model to DiCE's model object
backend = 'sklearn'
m = dice_ml.Model(model=model, backend=backend)

# initiate DiCE
exp_random = dice_ml.Dice(d, m, method="random")

for idx in [0, 888]:

    query_instances = X_train[idx:idx+1]

    # generate counterfactuals
    dice_exp_random = exp_random.generate_counterfactuals(query_instances,
        total_CFs=4, desired_class="opposite", verbose=False)

    dice_exp_random.visualize_as_dataframe(show_only_changes=True)
```

```
0:  learn: 0.4365535    total: 54.3ms   remaining: 10.8s
20: learn: 0.2308290    total: 100ms    remaining: 855ms
40: learn: 0.1532110    total: 135ms    remaining: 524ms
60: learn: 0.1088515    total: 166ms    remaining: 379ms
80: learn: 0.0825711    total: 195ms    remaining: 286ms
100:    learn: 0.0716318    total: 221ms    remaining: 217ms
120:    learn: 0.0629166    total: 247ms    remaining: 162ms
140:    learn: 0.0562215    total: 274ms    remaining: 115ms
160:    learn: 0.0522724    total: 302ms    remaining: 73.1ms
180:    learn: 0.0480928    total: 329ms    remaining: 34.6ms
199:    learn: 0.0464085    total: 355ms    remaining: 0us
continuous_features =  ['Pclass', 'Age', 'SibSp', 'Parch', 'Fare']
100%|██████████| 1/1 [00:00<00:00,  2.12it/s]
Query instance (original outcome : 0)
   Pclass   Sex   Age  SibSp  Parch  Fare  Survived
0       3  male  22.0      1      0  7.25         0

Diverse Counterfactual set (new outcome: 1.0)
  Pclass     Sex   Age SibSp Parch    Fare Survived
0      -  female     -     -     -  112.29        1
1      -  female     -     -   5.0       7        1
2      -       -  15.5     -     -   449.6        1
3      -  female     -     -     -  431.99        1
  0%|          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "conterfactual-malwares-titanic.py", line 67, in <module>
    total_CFs=4, desired_class="opposite", verbose=False)
  File "<home>/.local/lib/python3.7/site-packages/dice_ml/explainer_interfaces/explainer_base.py", line 102, in generate_counterfactuals
    **kwargs)
  File "<home>/.local/lib/python3.7/site-packages/dice_ml/explainer_interfaces/dice_random.py", line 171, in _generate_counterfactuals
    final_cfs_df_sparse, test_instance_df, posthoc_sparsity_param, posthoc_sparsity_algorithm)
  File "<home>/.local/lib/python3.7/site-packages/dice_ml/explainer_interfaces/explainer_base.py", line 420, in do_posthoc_sparsity_enhancement
    diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
ValueError: cannot convert float NaN to integer
```
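The last frame fails on a plain Python rule: `int()` cannot convert NaN. The sketch below reproduces only that step with pandas, outside DiCE. It assumes (not confirmed from the DiCE source) that a counterfactual row which leaves `Age` untouched inherits the query's NaN; the one-row frames and the function name `reproduce_cast_error` are hypothetical.

```python
import pandas as pd


def reproduce_cast_error() -> str:
    # Hypothetical one-row query with a missing 'Age' value
    query_instance = pd.DataFrame({"Age": [float("nan")], "Fare": [7.25]})
    # Assumption: a counterfactual that leaves 'Age' unchanged keeps the query's NaN
    final_cfs_sparse = query_instance.copy()
    try:
        # Same expression shape as the failing line in do_posthoc_sparsity_enhancement
        _ = query_instance["Age"].iat[0] - int(final_cfs_sparse.at[0, "Age"])
    except ValueError as err:
        return str(err)
    return "no error"


print(reproduce_cast_error())
```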
gaugup commented 3 years ago

@londumas does any sample in the query_instance have a missing value? I think it is OK if dice-ml doesn't handle this scenario. Shouldn't the user worry about supplying a legitimate value for every column in query_instance? How would dice-ml go about generating CFs if the value in a column is not known a priori? We should probably raise an exception stating that the query_instance has missing values, rather than erroring out during the generation of counterfactuals.

Regards,
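A minimal sketch of the up-front check suggested above, run on the query rows before calling `generate_counterfactuals`. `check_query_instances` is a hypothetical helper, not part of `dice_ml`; it relies only on pandas' `DataFrame.isna()`.

```python
import pandas as pd


def check_query_instances(query_instances: pd.DataFrame) -> None:
    """Hypothetical pre-check: fail early if any query feature is missing."""
    # Columns that contain at least one missing value in any query row
    missing_cols = query_instances.columns[query_instances.isna().any()].tolist()
    if missing_cols:
        raise ValueError(
            f"query_instances has missing values in columns {missing_cols}; "
            "fill or drop them before generating counterfactuals."
        )
```

Called on each `query_instances` slice inside the loop, it would stop with an error naming the offending column(s) instead of failing later inside the post-hoc sparsity step.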

candalfigomoro commented 2 years ago

@gaugup I think a missing value is a legitimate value. Packages such as xgboost and lightgbm support features with missing values natively. I think DiCE should also handle missing values without requiring users to impute them.
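If missing values were treated as legitimate, steps that compare a counterfactual with the query (such as the post-hoc sparsity step in the traceback) would need a NaN-aware notion of "changed" instead of casting to `int`. A minimal sketch, assuming missing-to-missing counts as unchanged and missing-to-observed (or the reverse) counts as a change; `feature_changed` is hypothetical and not a `dice_ml` function.

```python
import pandas as pd


def feature_changed(query_value, cf_value) -> bool:
    """Hypothetical NaN-aware comparison where 'missing' is its own value."""
    q_missing, c_missing = pd.isna(query_value), pd.isna(cf_value)
    if q_missing or c_missing:
        # Missing -> missing is no change; missing <-> observed is a change
        return bool(q_missing != c_missing)
    # Both observed: ordinary comparison (works for numbers and categories)
    return bool(query_value != cf_value)


print(feature_changed(float("nan"), float("nan")))  # False
print(feature_changed(float("nan"), 30.0))          # True
```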

urigott commented 10 months ago

I came across this issue as well. In many cases missing values are very informative (just as informative as "real" values), and since DiCE can work with models that accept missing values (such as LGBM, XGBoost, CatBoost), it would be great if it could handle missing values too.
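One way a random-sampling search could offer "missing" as a counterfactual value is to mix NaN into the candidate draws for a feature the model accepts as missing. A minimal NumPy sketch; `sample_candidates` and its `p_missing` parameter are hypothetical, and this is not DiCE's implementation.

```python
import numpy as np


def sample_candidates(low: float, high: float, n: int,
                      p_missing: float, seed: int = 0) -> np.ndarray:
    """Hypothetical: n candidate values for a continuous feature, each NaN
    with probability p_missing, otherwise uniform in [low, high)."""
    rng = np.random.default_rng(seed)
    values = rng.uniform(low, high, size=n)
    # Turn a random subset into NaN so "missing" can be proposed as a CF value
    values[rng.random(n) < p_missing] = np.nan
    return values
```

The resulting rows could then be scored directly by a model that accepts NaN inputs; whether "missing" should be an allowed counterfactual for a given feature remains a modelling decision.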