scikit-learn / scikit-learn

scikit-learn: machine learning in Python
https://scikit-learn.org
BSD 3-Clause "New" or "Revised" License

add sklearn.metrics Display class to plot Precision/Recall/F1 for probability thresholds #21391

Open dayyass opened 3 years ago

dayyass commented 3 years ago

Describe the workflow you want to enable

When working with binary classifiers, in addition to the PR curve and the ROC curve, I often need a plot of precision, recall and F1 (y-axis) against the probability threshold (x-axis).

Describe your proposed solution

import numpy as np
# Proposed API: PrecisionRecallF1Display does not exist yet in scikit-learn.
from sklearn.metrics import precision_recall_curve, PrecisionRecallF1Display

y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

display = PrecisionRecallF1Display(precision, recall, thresholds, plot_f1=True)
display.plot()
[Figure (prf1-curve): precision, recall and F1 plotted against the probability threshold]
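Until such a Display exists, the same figure can be drawn with the existing `precision_recall_curve` and matplotlib. This is a minimal sketch, assuming matplotlib is installed; it is not the proposed Display. Note that `precision` and `recall` have one more entry than `thresholds` (the final point has no threshold), so that last entry is dropped before plotting.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# The last (precision=1, recall=0) point has no matching threshold; drop it
# so that all arrays line up with `thresholds`.
precision, recall = precision[:-1], recall[:-1]

# F1 is the harmonic mean of precision and recall; avoid dividing 0 by 0.
denom = precision + recall
f1 = np.divide(2 * precision * recall, denom,
               out=np.zeros_like(denom), where=denom > 0)

fig, ax = plt.subplots()
ax.plot(thresholds, precision, label="Precision")
ax.plot(thresholds, recall, label="Recall")
ax.plot(thresholds, f1, label="F1")
ax.set_xlabel("Probability threshold")
ax.set_ylabel("Score")
ax.legend()
plt.show()
```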

Describe alternatives you've considered, if relevant

No response

Additional context

No response

glemaitre commented 3 years ago

We already have the PrecisionRecallDisplay. One feature missing there is the F1 iso-lines that we plot in the final figure of the following example:

https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html#sphx-glr-auto-examples-model-selection-plot-precision-recall-py

dayyass commented 3 years ago

That is not the same, because my implementation has probability thresholds as x-axis.

glemaitre commented 3 years ago

Yes, but I am not sure this is the best way to represent the precision-recall metrics: this way of plotting gives no sense of the average precision.

dayyass commented 3 years ago

You are right, but this plot is meant for choosing the probability threshold that gives a better model. It shows which threshold yields a given balance between precision and recall.

glemaitre commented 3 years ago

It is true that we don't show any information about the threshold, and it would be nice to add it to the current Displays (ROC curve and precision-recall curve). I could imagine displaying the thresholds as annotated markers on the curves. Since the plot could become crowded, we could accept an API that (i) does not display the markers, (ii) displays all thresholds, or (iii) displays only a given number of thresholds.
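The three options (i)-(iii) could map onto a single parameter. The sketch below uses a hypothetical `n_markers` parameter (not an existing scikit-learn argument) and only decides which threshold indices to annotate; drawing is left to the Display.

```python
import numpy as np


def select_threshold_markers(thresholds, n_markers=None):
    """Hypothetical helper: indices of the thresholds to annotate.

    n_markers=None  -> (i) no markers
    n_markers="all" -> (ii) every threshold
    n_markers=k     -> (iii) at most k thresholds, evenly spaced along the curve
    """
    n = len(thresholds)
    if n_markers is None:
        return np.array([], dtype=int)
    if isinstance(n_markers, str) and n_markers == "all":
        return np.arange(n)
    k = min(int(n_markers), n)
    # Evenly spaced positions over the threshold array; rounding can create
    # duplicates for small arrays, so deduplicate.
    return np.unique(np.linspace(0, n - 1, num=k).round().astype(int))


thresholds = np.linspace(0, 1, 101)
print(select_threshold_markers(thresholds, 5))
```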

glemaitre commented 3 years ago

This would also motivate some example/analysis when introducing the meta-estimator developed in: https://github.com/scikit-learn/scikit-learn/pull/16525

dayyass commented 3 years ago

So for the PR curve, we keep recall on the x-axis and precision on the y-axis, but add the thresholds as markers on the curve?

glemaitre commented 3 years ago

Something in the following spirit:

[Figure: precision-recall curve with thresholds shown as annotated markers]

But we should not plot all the markers, and we need a nice way to place the annotations. We could even colour the annotations differently to show whether a threshold is above the one used for the argmax (e.g. 0.5 for predict_proba and 0 for decision_function).
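A rough sketch of that idea with plain matplotlib: a PR curve with a handful of annotated thresholds, coloured by whether each lies above the default decision threshold of 0.5 for `predict_proba`. The dataset, the number of markers (6) and the colour choice are arbitrary assumptions for illustration, not a proposed API.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)
y_scores = clf.predict_proba(X_test)[:, 1]

precision, recall, thresholds = precision_recall_curve(y_test, y_scores)
default_threshold = 0.5  # 0.5 for predict_proba; it would be 0.0 for decision_function

fig, ax = plt.subplots()
ax.plot(recall, precision)
# Annotate only a few evenly spaced thresholds to keep the plot readable.
for idx in np.linspace(0, len(thresholds) - 1, num=6).round().astype(int):
    colour = "tab:green" if thresholds[idx] >= default_threshold else "tab:red"
    ax.plot(recall[idx], precision[idx], "o", color=colour)
    ax.annotate(f"{thresholds[idx]:.2f}", (recall[idx], precision[idx]),
                textcoords="offset points", xytext=(5, 5), color=colour)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
plt.show()
```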

dayyass commented 3 years ago

This is a nice idea and I am in favor of implementing it, but I think this plot gives a less complete understanding of how the metrics behave across probability thresholds than the approach I suggested.

When a threshold has to be chosen, many of my friends and colleagues plot the threshold on the x-axis and the metrics (precision, recall, F1-score) on the y-axis.

dayyass commented 3 years ago

@glemaitre, 1) could you send your implementation of the plot, so that I don't have to implement it from scratch? 2) could you review my PR with the implementation of my plot?

glemaitre commented 3 years ago

> could you send your implementation of the plot, so that I don't have to implement it from scratch?

It was a dirty hack to make the plot. I did not keep the changes.

> could you review my PR with the implementation of my plot?

I am not convinced that I would like this type of plot in scikit-learn. I would prefer to improve the plotting that we currently have. I think adding an option to compute precision gain and recall gain, as proposed in http://people.cs.bris.ac.uk/~flach/PRGcurves/PRcurves_ISL.pdf, would be better.
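For reference, the precision-recall-gain transform from the linked paper (Flach & Kull) rescales each metric relative to the proportion of positives π: precG = (prec − π) / ((1 − π)·prec), and likewise recG = (rec − π) / ((1 − π)·rec). The sketch below applies it to the output of `precision_recall_curve`; the helper name is hypothetical and this is not an existing scikit-learn option.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def precision_recall_gain(precision, recall, pi):
    """Hypothetical helper: precision gain and recall gain.

    pi is the proportion of positive samples. Points with precision or
    recall below pi get negative gains and are usually dropped before plotting.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    # recall == 0 at the end of the PR curve gives -inf; silence that warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        prec_gain = (precision - pi) / ((1 - pi) * precision)
        rec_gain = (recall - pi) / ((1 - pi) * recall)
    return prec_gain, rec_gain


y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
precision, recall, _ = precision_recall_curve(y_true, y_scores)

pi = y_true.mean()  # proportion of positives
prec_gain, rec_gain = precision_recall_gain(precision, recall, pi)
# Keep the region where both gains are non-negative (precision, recall >= pi).
keep = (prec_gain >= 0) & (rec_gain >= 0)
```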

However, this is only my opinion, so I would like to hear from other core devs on the matter, notably @amueller, who has probably looked at these kinds of plots more than I have. Maybe @adrinjalali @thomasjpfan @NicolasHug have some opinions as well.

adrinjalali commented 3 years ago

I'm quite in favor of one or more plots which investigate the threshold. I'd agree with @dayyass that the two types of plots discussed here are very different in nature and in what they convey. I personally understand the proposed plot in the OP much better than a precision recall curve, and that's probably because I have never ended up dealing with precision recall curves much in my career, whereas I've worked on finding thresholds for my models in a few instances.

In terms of API, I'd be happier with something like a MetricThresholdCurve that accepts different metrics (precision, recall, F1, ...) and plots them against the thresholds. We can then use it three times to achieve what the OP suggests.
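The computation behind such a MetricThresholdCurve could be as simple as thresholding the scores and calling any `metric(y_true, y_pred)` function. The sketch below uses a hypothetical `metric_threshold_curve` function (not a scikit-learn API) and predicts the positive class when `score >= threshold`, the same convention `precision_recall_curve` uses; calling it three times reproduces the OP's three curves.

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score


def metric_threshold_curve(y_true, y_score, metric, thresholds=None):
    """Hypothetical helper: evaluate metric(y_true, y_pred) at each threshold."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    if thresholds is None:
        thresholds = np.unique(y_score)  # sorted distinct scores
    values = np.array(
        [metric(y_true, (y_score >= t).astype(int)) for t in thresholds]
    )
    return values, np.asarray(thresholds)


y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

# One call per metric gives the precision / recall / F1 curves from the OP.
curves = {
    name: metric_threshold_curve(y_true, y_scores, metric)[0]
    for name, metric in [("precision", precision_score),
                         ("recall", recall_score),
                         ("f1", f1_score)]
}
```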

dayyass commented 3 years ago

I agree with the proposed API - I will redo the current implementation.

vitaliset commented 1 year ago

Hi @dayyass, are you still working on this feature? I would love to do it if you don't mind!

dayyass commented 1 year ago

Hi @vitaliset! I don't have enough time to do it right now, so sure, you can take it over. That would be great!

vitaliset commented 1 year ago

Hello @dayyass and @adrinjalali! I first implemented the "base" function to build the Display later. I would appreciate it if you could take a quick look at the PR to see if I got the idea right. :)

vitaliset commented 1 year ago

I didn't know about sklearn.model_selection.validation_curve. Combined with PR #16525, it may already solve this issue if the threshold can be passed as a parameter (the threshold would then be treated as a hyperparameter to search over).

It would be less straightforward than using the curve implemented in #25639, but it would mean less code to maintain. An example or user guide entry would be handy in that case. Then again, maybe this separate curve is more intuitive to use and we want to go forward with it.

What do you think, @glemaitre? As you are building #16525, I think you have a broader view than I do on this.

FernandoCarazoMelo commented 5 months ago

@dayyass, thanks for your issue; I was looking for the same feature.

Would the following implementation solve your request?

[Figure: proposed implementation]

dayyass commented 5 months ago

@FernandoCarazoMelo, that looks similar to what I wanted to do.