SeldonIO / alibi

Algorithms for explaining machine learning models
https://docs.seldon.io/projects/alibi/en/stable/

IndexError: tuple index out of range #803

Open pranavn91 opened 1 year ago

pranavn91 commented 1 year ago

I used scikit-learn 0.24.2 to train a random forest classifier and then used CounterfactualProto as shown in the example below:

https://docs.seldon.io/projects/alibi/en/stable/examples/cfproto_housing.html

from alibi.explainers import CounterfactualProto
y30cf = np.zeros((y30.shape[0],))
y30cf[np.where(y30 > np.median(y30))[0]] = 1

y becomes a binary classification task (1 above the median, 0 otherwise):

y30cf array([1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.])

Model trained and ready:

bestmodel RandomForestClassifier(ccp_alpha=0.1, criterion='entropy', max_depth=1, max_features=None, max_samples=0.5, min_impurity_decrease=0.1, min_samples_leaf=0.5, min_samples_split=0.1, min_weight_fraction_leaf=0.5, n_estimators=1)

Took one test sample:

X = X_test[1].reshape((1,) + X_test[1].shape)
shape = X.shape

shape (1, 39)

X array([[ 0. , 0. , 0. , 25. , 0. , 75. , 0. , 0. , 0. , 0. , 70. ,
          1. , 0. , 0. , 0. , 55. , 0. , 0. , 40. , 0. , 0. , 0. ,
          2. , 45. , 0. , 1.5, 2. , 0. , 3. , 0. , 0. , 0.5, 0. ,
          3. , 0. , 0.3, 0. , 0. , 8. ]])

Defined a black-box model:

predict_fn = lambda x: bestmodel.predict(x)

I am getting IndexError: tuple index out of range when initializing the explainer. How do I resolve this error?

cf = CounterfactualProto(predict_fn, shape, use_kdtree=True, theta=10., max_iterations=1000, feature_range=(X_train.min(axis=0), X_train.max(axis=0)), c_init=1., c_steps=10)


IndexError                                Traceback (most recent call last)
Input In [27], in <cell line: 2>()
      1 # initialize explainer, fit and generate counterfactual
----> 2 cf = CounterfactualProto(predict_fn, shape, use_kdtree=True, theta=10., max_iterations=1000,
      3                          feature_range=(X_train.min(axis=0), X_train.max(axis=0)),
      4                          c_init=1., c_steps=10)

File C:\ProgramData\Anaconda3\lib\site-packages\alibi\explainers\cfproto.py:139, in CounterfactualProto.__init__(self, predict, shape, kappa, beta, feature_range, gamma, ae_model, enc_model, theta, cat_vars, ohe, use_kdtree, learning_rate_init, max_iterations, c_init, c_steps, eps, clip, update_num_grad, write_dir, sess)
    137 else:  # black-box model
    138     self.model = False
--> 139     self.classes = self.predict(np.zeros(shape)).shape[1]
    141 if is_enc:
    142     self.enc_model = True

IndexError: tuple index out of range
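The failing line in the traceback calls the black-box `predict` on a zero-filled batch and reads `.shape[1]` to count the classes. The minimal NumPy sketch below uses two hypothetical stand-in predictors (not the real model) to show why a 1-D label output breaks this check while a 2-D output with one column per class passes it:

```python
import numpy as np

shape = (1, 39)  # a batch of one sample with 39 features, as in the issue

# Hypothetical stand-in for RandomForestClassifier.predict fit on 1-D labels:
# one label per sample, so the output is 1-D.
def predict_labels(x):
    return np.zeros(x.shape[0])

# Hypothetical stand-in for a predictor that returns one column per class.
def predict_two_columns(x):
    return np.tile([0.5, 0.5], (x.shape[0], 1))

out_1d = predict_labels(np.zeros(shape))
out_2d = predict_two_columns(np.zeros(shape))
print(out_1d.shape)  # (1,)
print(out_2d.shape)  # (1, 2)

# Mirrors the check on line 139 of cfproto.py: the class count is read from axis 1.
try:
    n_classes = out_1d.shape[1]
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range

n_classes = out_2d.shape[1]
print(n_classes)  # 2
```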

mauicv commented 1 year ago

Hey @pranavn91, Thanks for opening the issue.

Counterfactuals with prototypes require the model, whether black-box or white-box, to be differentiable, which random forests are not. Hence I'm not sure you'll get good results with this method. You might want to look at the CounterfactualRL method instead (this example uses a random forest classifier).
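To illustrate the differentiability point: a tree ensemble's output is piecewise constant in its inputs, so a finite-difference gradient estimate is zero almost everywhere and gives a gradient-based search no direction to move in. The sketch below uses a hypothetical one-split "stump" as a stand-in for a depth-1 tree (like the `max_depth=1` model above) and a plain central difference; it is an illustration under those assumptions, not alibi's own gradient code.

```python
import numpy as np

# Hypothetical stand-in for a depth-1 tree: the predicted probability jumps
# from 0.2 to 0.8 when feature 0 crosses a threshold of 25.0.
def stump_proba(x):
    return np.where(x[:, 0] > 25.0, 0.8, 0.2)

def central_difference(f, x, feature, eps=1e-2):
    """Central finite-difference estimate of df/dx[feature] for each row of x."""
    step = np.zeros_like(x)
    step[:, feature] = eps
    return (f(x + step) - f(x - step)) / (2 * eps)

# Points away from the threshold: the estimated gradient is exactly zero.
x_far = np.array([[10.0], [40.0]])
grad_far = central_difference(stump_proba, x_far, feature=0)
print(grad_far)  # [0. 0.]

# Only a point whose perturbation straddles the threshold sees a nonzero estimate.
x_near = np.array([[25.0]])
grad_near = central_difference(stump_proba, x_near, feature=0)
print(grad_near)
```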

With regard to your issue, I'm finding it difficult to reproduce. The following:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from alibi.explainers import CounterfactualProto

# x_train, y_train and x_test are assumed to be prepared as in the linked housing example
X = x_test[1].reshape((1,) + x_test[1].shape)
shape = X.shape

clf = RandomForestClassifier(ccp_alpha=0.1, criterion='entropy', max_depth=1,
                             max_features=None, max_samples=0.5,
                             min_impurity_decrease=0.1, min_samples_leaf=0.5,
                             min_samples_split=0.1, min_weight_fraction_leaf=0.5,
                             n_estimators=1)

clf.fit(x_train, y_train)

predict_fn = lambda x: clf.predict(x)
# the same check CounterfactualProto makes internally; it needs a 2-D prediction
predict_fn(np.zeros(shape)).shape[1]

cf = CounterfactualProto(predict_fn, shape, use_kdtree=True, theta=10., max_iterations=1000,
                         feature_range=(x_train.min(axis=0), x_train.max(axis=0)),
                         c_init=1., c_steps=10)

This doesn't throw the same error. How exactly does your code differ? If you could paste your complete code, that might help.

pranavn91 commented 1 year ago

Thanks, I will try the new link. The error goes away if I one-hot encode the labels. I didn't do this because the labels were binary. My bad.
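One-hot encoding fixes the error because a scikit-learn classifier fit on 2-D labels predicts a 2-D array, so `.shape[1]` exists. The minimal NumPy sketch below shows the encoding on a short hypothetical sample of median-split labels. An alternative, which is an assumption here and was not tested in this thread, is to keep the 1-D labels and pass the model's `predict_proba` (which returns one column per class) as the black-box function.

```python
import numpy as np

# Short hypothetical sample of the median-split labels from the issue.
y30cf = np.array([1., 0., 0., 1., 1.])

# One-hot encode: row i has a 1 in column y30cf[i] and 0 elsewhere.
y_onehot = np.eye(2)[y30cf.astype(int)]
print(y_onehot.shape)  # (5, 2)

# Alternative (assumption, not tested in this thread): train on the 1-D labels
# and give the explainer class probabilities instead, e.g.
#     predict_fn = lambda x: bestmodel.predict_proba(x)
# whose output has shape (n_samples, n_classes).
```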