sagyome / XGBoostTreeApproximator

This project implements a method that converts a trained XGBoost classification model into a single decision tree.

error happened when pruning #4

Open joshua-xia opened 1 year ago

joshua-xia commented 1 year ago

def predict_probas_tree(self, conjunctions, X):
    """
    Predict probabilities for X using a tree, represented as a conjunction set

    :param conjunctions: A list of conjunctions
    :param X: numpy array of data instances
    :return: class probabilities for each instance of X
    """
    probas = []
    for inst in X:
        for conj in conjunctions:
            if conj.containsInstance(inst):
                probas.append(conj.label_probas)
    return np.array(probas)

The length of the returned probas is not equal to the length of X, so the error occurs in get_auc:

selected_indexes = [np.argmax([get_auc(Y,trees_predictions[i]) for i in trees_predictions])]
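For what it's worth, the mismatch reported below (24554 probability rows for 20000 instances) suggests that some instances satisfy more than one conjunction, while others may satisfy none. The following is only a minimal sketch of a version that always returns exactly one probability vector per instance, not a patch from the repository; the uniform fallback for uncovered instances is an assumption:

import numpy as np

def predict_probas_tree(self, conjunctions, X):
    """
    Predict probabilities for X, returning exactly one probability row per instance.
    """
    probas = []
    for inst in X:
        matched = False
        for conj in conjunctions:
            if conj.containsInstance(inst):
                probas.append(conj.label_probas)  # keep only the first matching conjunction
                matched = True
                break
        if not matched:
            # assumed fallback: uniform distribution when no conjunction covers the instance
            n_classes = len(conjunctions[0].label_probas)
            probas.append([1.0 / n_classes] * n_classes)
    return np.array(probas)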

joshua-xia commented 1 year ago

ValueError Traceback (most recent call last)

<ipython-input> in <module>
      9
     10
---> 11 fbt.fit(train, feature_key, label_col, model)

~/workspace/explain_AI/explain_AI/FBT.py in fit(self, train, feature_cols, label_col, xgb_model, pruned_forest, trees_conjunctions_total)
     49             self.trees_conjunctions_total = extractConjunctionSetsFromForest(self.xgb_model,train[self.label_col].unique(),self.feature_cols)
     50             print('Start pruning')
---> 51             self.prune(train)
     52         else:
     53             self.pruner = Pruner()

~/workspace/explain_AI/explain_AI/FBT.py in prune(self, train)
     72         if self.pruning_method == 'auc':
     73             self.trees_conjunctions = self.pruner.max_auc_pruning(self.trees_conjunctions_total, train[self.feature_cols],
---> 74                                                                    train[self.label_col], min_forest_size=self.min_forest_size)
     75
     76     def predict_proba(self,X):

~/workspace/explain_AI/explain_AI/pruning.py in max_auc_pruning(self, forest, X, Y, min_forest_size)
    119         X = X.values
    120         trees_predictions = {i: self.predict_probas_tree(forest[i],X) for i in range(len(forest))} #predictions are stored beforehand for efficiency purposes
--> 121         selected_indexes = [np.argmax([get_auc(Y,trees_predictions[i]) for i in trees_predictions])] #get the tree with the highest AUC for the given dataset
    122         previous_auc = 0
    123         best_auc = get_auc(Y,trees_predictions[selected_indexes[0]])

~/workspace/explain_AI/explain_AI/pruning.py in <listcomp>(.0)
    119         X = X.values
    120         trees_predictions = {i: self.predict_probas_tree(forest[i],X) for i in range(len(forest))} #predictions are stored beforehand for efficiency purposes
--> 121         selected_indexes = [np.argmax([get_auc(Y,trees_predictions[i]) for i in trees_predictions])] #get the tree with the highest AUC for the given dataset
    122         previous_auc = 0
    123         best_auc = get_auc(Y,trees_predictions[selected_indexes[0]])

~/workspace/explain_AI/explain_AI/utils.py in get_auc(test_y, y_score)
     30     classes=[i for i in range(y_score.shape[1])]
     31     y_test_binarize=np.array([[1 if i ==c else 0 for c in classes] for i in test_y])
---> 32     fpr, tpr, _ = roc_curve(y_test_binarize.ravel(), y_score.ravel())
     33     return auc(fpr, tpr)
     34

/opt/conda/envs/main/lib/python3.6/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     61             extra_args = len(args) - len(all_args)
     62             if extra_args <= 0:
---> 63                 return f(*args, **kwargs)
     64
     65             # extra_args > 0

/opt/conda/envs/main/lib/python3.6/site-packages/sklearn/metrics/_ranking.py in roc_curve(y_true, y_score, pos_label, sample_weight, drop_intermediate)
    912     """
    913     fps, tps, thresholds = _binary_clf_curve(
--> 914         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight)
    915
    916     # Attempt to drop thresholds corresponding to points in between and

/opt/conda/envs/main/lib/python3.6/site-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    691         raise ValueError("{0} format is not supported".format(y_type))
    692
--> 693     check_consistent_length(y_true, y_score, sample_weight)
    694     y_true = column_or_1d(y_true)
    695     y_score = column_or_1d(y_score)

/opt/conda/envs/main/lib/python3.6/site-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    318     if len(uniques) > 1:
    319         raise ValueError("Found input variables with inconsistent numbers of"
--> 320                          " samples: %r" % [int(l) for l in lengths])
    321
    322

ValueError: Found input variables with inconsistent numbers of samples: [20000, 24554]
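In case it helps narrow this down, a small diagnostic along these lines can show which instances are covered by zero or by multiple conjunctions of a single tree. The names fbt, train and feature_key are reused from the snippets above, and the attribute layout (trees_conjunctions_total being a list of conjunction lists) is an assumption:

import numpy as np

def coverage_counts(conjunctions, X):
    """Count, for each instance, how many conjunctions of one tree claim to contain it."""
    return np.array([sum(conj.containsInstance(inst) for conj in conjunctions) for inst in X])

# Hypothetical usage, reusing names from the snippets above:
# counts = coverage_counts(fbt.trees_conjunctions_total[0], train[feature_key].values)
# print((counts == 0).sum(), "instances matched no conjunction")
# print((counts > 1).sum(), "instances matched more than one conjunction")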