ploomber / sklearn-evaluation

Machine learning model evaluation made easy: plots, tables, HTML reports, experiment tracking and Jupyter notebook analysis.
https://sklearn-evaluation.ploomber.io
Apache License 2.0
455 stars 54 forks

[Feature] Thresholds in Precision-Recall Multiclass Curve #319

Closed Buedenbender closed 1 year ago

Buedenbender commented 1 year ago

Thank you for this great package.

TL;DR

I would like to obtain the thresholds used to create the multiclass precision-recall curve with the plot.precision_recall() function.

Details

For binary classification, scikit-learn offers precision_recall_curve() (docs), which also returns the thresholds. However, using only scikit-learn there is no easy way (without lots of boilerplate code) to get multiple precision-recall curves in a multiclass scenario. Luckily, sklearn-evaluation offers a way out with its multiclass precision-recall curve plot (here), which produces a publication-quality plot. It would be extremely helpful if it could also return a list (or dict) of thresholds, or alternatively include them in the plot (e.g., an additional axis along the top border showing the threshold).
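
For reference, the scikit-learn-only boilerplate looks roughly like the sketch below: binarize the labels one-vs-rest and call precision_recall_curve once per class, keeping the thresholds. The helper name per_class_pr_curves and the returned dict layout are assumptions made for illustration and are not part of either library. The sketch also assumes three or more classes, since label_binarize returns a single column for two.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize


def per_class_pr_curves(y_true, y_score, classes):
    """One-vs-rest precision, recall and thresholds for each class.

    Hypothetical helper: assumes y_score has one column per entry in
    `classes` (e.g. the output of predict_proba) and at least 3 classes.
    """
    y_score = np.asarray(y_score)
    # shape (n_samples, n_classes): column i is 1 where y_true == classes[i]
    y_bin = label_binarize(y_true, classes=classes)
    curves = {}
    for i, cls in enumerate(classes):
        precision, recall, thresholds = precision_recall_curve(
            y_bin[:, i], y_score[:, i]
        )
        curves[cls] = {
            "precision": precision,
            "recall": recall,
            "threshold": thresholds,
        }
    return curves
```

Note that precision and recall each have one more entry than the thresholds, because the final point of the curve has no associated threshold.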

Why ...

.. you may ask? This increases the utility and value one can gain from the curve. After all, once finished inspecting the trade-off between precision and recall, one ultimately wants to determine a threshold.
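
As a small illustration of that last step, the sketch below picks a threshold from a binary curve once the thresholds are available. Maximizing F1 is only one possible criterion, chosen here as an assumption, and the data is a toy example.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# toy binary data, for illustration only
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, y_score)

# precision and recall have one more entry than thresholds: the final
# (precision=1, recall=0) point has no threshold, so drop it
p, r = precision[:-1], recall[:-1]

# F1 at each threshold; np.maximum guards against division by zero
f1 = 2 * p * r / np.maximum(p + r, 1e-12)
best_threshold = thresholds[np.argmax(f1)]
```

Without the thresholds array, the best point on the curve can be located but not mapped back to a usable cutoff.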

Implementation plan

Implement ... what?

Either (or both?)

How to implement it?

List / dict of thresholds: maybe as a class method (since everything is OOP / a class), e.g.,

@classmethod
def get_thresholds(cls):
    pass

Argument to include them in the plot

pass

Where to implement?

Some first insights from just scrolling through the code of precision_recall: currently the thresholds are discarded by assigning them to _ (lines 48-55):

def _precision_recall_metrics(y_true, y_score):
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    return precision, recall

def _precision_recall_metrics_multiclass(y_true, y_score):
    precision, recall, _ = precision_recall_curve(y_true.ravel(), y_score.ravel())
    return precision, recall

edublancas commented 1 year ago

Hi @Buedenbender, thanks for your feedback!

This makes sense. I think the object-oriented API is an excellent option to enable this because it allows us to store data in the object after plotting.

I'm thinking something like this:

from sklearn_evaluation import plot

pr = plot.PrecisionRecall.from_raw_data(y_true, y_score)

# return the thresholds (note the trailing underscore:
# by convention, this means the attribute is set after plotting)
pr.thresholds_
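
To make the trailing-underscore convention concrete, here is a minimal sketch with a hypothetical ToyPrecisionRecall class. It is not the sklearn-evaluation implementation and it draws nothing; it only shows the attributes appearing once plot() has run, the way scikit-learn sets fitted attributes in fit().

```python
from sklearn.metrics import precision_recall_curve


class ToyPrecisionRecall:
    """Hypothetical stand-in, not the sklearn-evaluation class."""

    def __init__(self, y_true, y_score):
        self.y_true = y_true
        self.y_score = y_score

    def plot(self):
        # attributes with a trailing underscore are only set here,
        # mirroring how scikit-learn sets fitted attributes in fit()
        self.precision_, self.recall_, self.thresholds_ = precision_recall_curve(
            self.y_true, self.y_score
        )
        return self
```

Before plot() is called, accessing thresholds_ raises AttributeError, which signals to the user that the value is derived rather than supplied.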

@yafimvo: I believe you worked on this, right? Can you provide some comments? Sounds like this would be straightforward.

@Buedenbender: do you have some time to contribute to this feature? we're happy to guide you

Buedenbender commented 1 year ago

Hi @edublancas, I would like to give it a try; it could be a good first issue to tackle.

edublancas commented 1 year ago

Awesome, if you need any help, let us know!

here's the contributing guide: https://github.com/ploomber/sklearn-evaluation/blob/master/CONTRIBUTING.md

Buedenbender commented 1 year ago

Follow-up question on the implementation: what do you think is the best form for the return value? In my mind it would be most beneficial to also return the precision and recall values, in a nested dict:

Outer dict: 'name_of_curve': inner dict

Inner dict: 'threshold': [threshold values], 'precision': [precision values], 'recall': [recall values]

edublancas commented 1 year ago

good point.

we try to stick to sklearn's guidelines as much as we can. according to their guidelines, by convention, they add attributes with trailing underscores. I think if there's a single curve, we might add .plot_data_ with threshold, precision and recall keys. If there are multiple curves, then .plot_data_ will be a dict whose keys are the curve names, as you suggested, each holding threshold, precision and recall.

We might encounter some edge cases or find that there are better ways to return them, so I'd suggest getting a first implementation and then deciding what's the best way to return them.

yafimvo commented 1 year ago

@yafimvo: I believe you worked on this, right? Can you provide some comments? Sounds like this would be straightforward.

I think @neelasha23 worked on this.

As for the implementation, I agree, I think it can be quite straightforward.

we try to stick to sklearn's guidelines as much as we can. according to their guidelines, by convention, they add attributes with trailing underscores. I think if there's a single curve, we might add .plot_data_ with threshold, precision and recall keys. If there are multiple curves, then .plot_data_ will be a dict whose keys are the curve names, as you suggested, each holding threshold, precision and recall.

In another approach, we can return 3 lists (threshold, precision, recall). This way we don't depend on the curve names and return the same structure every time. We can access every value with an index.
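
A rough sketch of this three-list layout, computed one-vs-rest with scikit-learn only. The helper name pr_lists is hypothetical, and the sketch assumes the class labels are the integers 0..n-1 in the same order as the columns of y_score.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def pr_lists(y_true, y_score):
    """Hypothetical helper: three parallel lists, one entry per class column."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    thresholds, precisions, recalls = [], [], []
    for i in range(y_score.shape[1]):
        # one-vs-rest: class i is the positive label for curve i
        p, r, t = precision_recall_curve((y_true == i).astype(int), y_score[:, i])
        precisions.append(p)
        recalls.append(r)
        thresholds.append(t)
    return thresholds, precisions, recalls
```

With this layout, thresholds[1], precisions[1] and recalls[1] all describe the second curve, and the same code works for a single curve (lists of length one).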

idomic commented 1 year ago

In another approach, we can return 3 lists (threshold, precision, recall). This way we don't depend on the curve names and return the same structure every time. We can access every value with an index.

Yeah, I think that's ideal since it can be generalized.

Buedenbender commented 1 year ago

There was already the _get_data() method, to which I was able to add the threshold (here). However, I was not able to get separate lists for each curve (in the multiclass scenario).

I do understand that for the multiclass case the method (__add__) is essential. Sadly, I am not quite sure how this method gets called or where it is called from.

    @SKLearnEvaluationLogger.log(feature="plot", action="precision-recall-add")
    def __add__(self, another):
        return PrecisionRecallAdd(
            precisions=[self.precision, another.precision],
            recalls=[self.recall, another.recall],
            thresholds=[self.threshold, another.threshold],
            labels=[self.label, another.label],
        ).plot()

edublancas commented 1 year ago

So I think we can rename all instances of _get_data -> get_data to make it part of the public API. The only remaining part would be to ensure that the PrecisionRecall object includes the thresholds in the object returned by _get_data, which it seems you already know how to do. I think we should store this under the threshold key. This will allow people to just call get_data to get the underlying plotting data.

I do understand that for the multiclass case the method (__add__) is essential. Sadly, I am not quite sure how this method gets called or where it is called from.

the __add__ method is used when doing plot1 + plot2; but you can also get the multi-class version without it:

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from sklearn_evaluation import plot

# generate data
X, y = datasets.make_classification(
    n_samples=2000, n_features=6, n_informative=4, class_sep=0.1, n_classes=3
)

# split data into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

est = RandomForestClassifier()
est.fit(X_train, y_train)

# y_pred = est.predict(X_test)
y_score = est.predict_proba(X_test)
y_true = y_test

# plot precision recall curve
pr = plot.PrecisionRecall.from_raw_data(y_true, y_score)

So to keep this simple, let's not modify the __add__ method.
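
For context on the plot1 + plot2 mechanics mentioned above: Python translates the + operator into a call to __add__ on the left operand. The sketch below uses hypothetical ToyCurve and ToyCurves classes (not the library's PrecisionRecall and PrecisionRecallAdd) and skips the plotting, just to show where the call comes from.

```python
class ToyCurves:
    """Hypothetical container for several curves."""

    def __init__(self, precisions, recalls, thresholds, labels):
        self.precisions = precisions
        self.recalls = recalls
        self.thresholds = thresholds
        self.labels = labels


class ToyCurve:
    """Hypothetical single-curve object, not the sklearn-evaluation class."""

    def __init__(self, precision, recall, threshold, label):
        self.precision = precision
        self.recall = recall
        self.threshold = threshold
        self.label = label

    def __add__(self, another):
        # Python calls this for `self + another`; collect both curves
        return ToyCurves(
            precisions=[self.precision, another.precision],
            recalls=[self.recall, another.recall],
            thresholds=[self.threshold, another.threshold],
            labels=[self.label, another.label],
        )


# toy values, for illustration only
a = ToyCurve([1.0, 1.0], [1.0, 0.0], [0.5], "a")
b = ToyCurve([0.5, 1.0], [1.0, 0.0], [0.3], "b")
both = a + b  # equivalent to a.__add__(b)
```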

Now, adding anything to the object returned by _get_data will break PrecisionRecall, since it inherits the from_dump method from the AbstractPlot abstract class, and that method passes the remaining keys of the dump to the constructor (cls(**data)). I think adding this to PrecisionRecall will fix it:


    @classmethod
    def from_dump(cls, path):
        """Instantiates a plot object from a path to a JSON file. A default
        implementation is provided, but you might override it.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        del data["version"]
        del data["class"]
        del data["threshold"]
        return cls(**data).plot()
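
The failure mode and the fix can be reproduced in isolation. The sketch below uses a hypothetical ToyPlot class whose constructor, like the case described above, has no threshold argument. As a deliberate variation on the snippet above, it uses pop with a default instead of del, on the assumption that dumps written before the threshold key existed should still load (del raises KeyError on a missing key).

```python
import json
import tempfile
from pathlib import Path


class ToyPlot:
    """Hypothetical plot class whose constructor has no `threshold` argument."""

    def __init__(self, precision, recall):
        self.precision = precision
        self.recall = recall

    @classmethod
    def from_dump(cls, path):
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        # drop metadata and any key the constructor does not accept;
        # pop(..., None) tolerates dumps that lack one of these keys
        for key in ("version", "class", "threshold"):
            data.pop(key, None)
        return cls(**data)


# toy dump, for illustration only
dump = {
    "class": "ToyPlot",
    "version": "0.0",
    "precision": [0.5, 1.0],
    "recall": [1.0, 0.0],
    "threshold": [0.3],
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp, "plot.json")
    path.write_text(json.dumps(dump), encoding="utf-8")
    restored = ToyPlot.from_dump(path)
```

Passing the unfiltered dict straight to the constructor, as in ToyPlot(**dump), raises TypeError because of the unexpected keys.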

edublancas commented 1 year ago

closing due to inactivity, feel free to re-open