casper-hansen / Nested-Cross-Validation

Nested cross-validation for unbiased predictions. Can be used with Scikit-Learn, XGBoost, Keras and LightGBM, or any other estimator that implements the scikit-learn interface.
MIT License

Add support for y being multiclass in combination with predict_proba=True #9

Open casper-hansen opened 5 years ago

casper-hansen commented 5 years ago

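Use this code to see why it won't work: for a multiclass y, predict_proba returns a matrix of shape (n_samples, n_classes), while y_test is a 1-D vector of labels, so the metric receives arrays of different dimensions.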

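from sklearn.utils.multiclass import type_of_target

def _predict_and_score(self, X_test, y_test):
    # XXX: Implement type_of_target(y)

    if self.predict_proba:
        y_type = type_of_target(y_test)
        # Compare with ==; `in ('binary')` is a substring test on a plain string, not a tuple lookup
        if y_type == 'binary':
            # Binary target: keep only the positive-class column, shape (n_samples,)
            pred = self.model.predict_proba(X_test)[:, 1]
        else:
            # Multiclass target: full probability matrix, shape (n_samples, n_classes)
            pred = self.model.predict_proba(X_test)
    else:
        pred = self.model.predict(X_test)

    # Print predictions and y_test to compare their shapes
    print(pred, pred.shape, y_test)

    if self.multiclass_average == 'binary':
        return self.metric(y_test, pred), pred
    else:
        return self.metric(y_test, pred, average=self.multiclass_average), pred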
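
To reproduce the mismatch outside the class, here is a minimal standalone sketch. It is not part of this repository: the synthetic data, the LogisticRegression model, and the choice of roc_auc_score as the metric are all assumptions made for illustration. It shows the two shapes side by side, and one possible way a metric can consume the full probability matrix (roc_auc_score with multi_class='ovr'); other metrics may need different handling.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import type_of_target

# Synthetic 3-class problem (illustrative data, not from the issue).
X, y = make_classification(
    n_samples=300, n_features=8, n_informative=5, n_classes=3, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=100, stratify=y, random_state=0
)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
proba = model.predict_proba(X_test)

target_type = type_of_target(y_test)
print(target_type)    # 'multiclass'
print(proba.shape)    # (100, 3): one column per class
print(y_test.shape)   # (100,): one label per sample

# With default arguments, roc_auc_score rejects a multiclass target.
try:
    roc_auc_score(y_test, proba)
    failed = False
except ValueError:
    failed = True

# Assumed workaround for this particular metric: one-vs-rest averaging
# over the full probability matrix.
score = roc_auc_score(y_test, proba, multi_class="ovr", average="macro")
print(failed, score)
```

In `_predict_and_score`, the equivalent change would be to pass metric-specific arguments (such as `multi_class`) when `type_of_target(y_test)` is `'multiclass'`, which is why the metric call would need to know the target type.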