Closed Buedenbender closed 1 year ago
Hi @Buedenbender, thanks for your feedback!
This makes sense, I think the object-oriented API is an excellent option to enable this because it allows us to store data in the object after plotting.
I'm thinking something like this:
from sklearn_evaluation import plot
pr = plot.PrecisionRecall.from_raw_data(y_true, y_score)
# return the thresholds (note the trailing underscore:
# by convention, this means the attribute is set after plotting)
pr.thresholds_
@yafimvo: I believe you worked on this, right? Can you provide some comments? Sounds like this would be straightforward.
@Buedenbender: do you have some time to contribute to this feature? we're happy to guide you
Hi @edublancas, I would like to give it a try; it could be a good first issue for me to tackle.
Awesome, if you need any help, let us know!
here's the contributing guide: https://github.com/ploomber/sklearn-evaluation/blob/master/CONTRIBUTING.md
Follow-up question on the implementation: what do you think is the best form for the return value? In my mind, it would be most beneficial to also return the precision and recall values in a nested dict. Outer dict: 'name_of_curve': inner dict
Inner dict: 'threshold': [threshold values], 'precision': [precision values], 'recall': [recall values]
good point.
We try to stick to sklearn's guidelines as much as we can. According to their guidelines, by convention, attributes estimated from the data get a trailing underscore. I think if there's a single curve, we might add .plot_data_
with threshold, precision and recall keys. If it's multiple curves, then .plot_data_
will have a dict whose keys are the name of the curve as you suggested and inside threshold, precision, recall.
We might encounter some edge cases or find that there are better ways to return them, so I'd suggest getting a first implementation and then we decide what's the best way to return them.
@yafimvo: I believe you worked on this, right? Can you provide some comments? Sounds like this would be straightforward.
I think @neelasha23 worked on this.
As for the implementation, I agree; I think it can be quite straightforward.
We try to stick to sklearn's guidelines as much as we can. According to their guidelines, by convention, attributes estimated from the data get a trailing underscore. I think if there's a single curve, we might add
.plot_data_
with threshold, precision and recall keys. If it's multiple curves, then .plot_data_
will have a dict whose keys are the name of the curve as you suggested and inside threshold, precision, recall.
In another approach, we can return 3 lists (threshold, precision, recall). This way we don't depend on the curve names and return the same structure every time. We can access every value with an index.
Yeah, I think that's ideal since it can be generalized.
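To make the two shapes discussed above concrete, here is a minimal sketch. The curve names, the numeric values, and the helper `to_parallel_lists` are all made up for illustration; none of them come from sklearn-evaluation.

```python
# Hypothetical plotting data for two curves, in the nested-dict shape
# (outer key: curve name, inner keys: threshold / precision / recall).
nested = {
    "class 0": {
        "threshold": [0.2, 0.6],
        "precision": [0.5, 1.0, 1.0],
        "recall": [1.0, 0.5, 0.0],
    },
    "class 1": {
        "threshold": [0.3, 0.7],
        "precision": [0.4, 0.8, 1.0],
        "recall": [1.0, 0.6, 0.0],
    },
}


def to_parallel_lists(plot_data):
    """Turn {curve_name: {...}} into three lists with one entry per curve."""
    thresholds, precisions, recalls = [], [], []
    for curve in plot_data.values():
        thresholds.append(curve["threshold"])
        precisions.append(curve["precision"])
        recalls.append(curve["recall"])
    return thresholds, precisions, recalls


thresholds, precisions, recalls = to_parallel_lists(nested)

# Curve i always lives at index i, regardless of how the curves are named.
first_curve = (thresholds[0], precisions[0], recalls[0])
```

With the three-list shape, a single curve is just the special case of lists with one entry, so callers do not need to branch on the number of curves.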
There was already the _get_data()
method, in which I was able to add threshold (here).
However I was not quite able to get different lists for all curves (in the multiclass scenario).
I do understand that for the multiclass case the method (__add__
) is essential.
Sadly, I am not quite sure how this method gets called or where it is called from.
@SKLearnEvaluationLogger.log(feature="plot", action="precision-recall-add")
def __add__(self, another):
    return PrecisionRecallAdd(
        precisions=[self.precision, another.precision],
        recalls=[self.recall, another.recall],
        thresholds=[self.threshold, another.threshold],
        labels=[self.label, another.label],
    ).plot()
So I think we can rename all instances of _get_data
-> get_data
to make it part of the public API. The only remaining part would be to ensure that the PrecisionRecall
object includes the thresholds in the object returned by _get_data
, which it seems you already know how to do. I think we should store them under the threshold
key. This will allow people to just call get_data
to get the underlying plotting data.
I do understand that for the multiclass case the method (__add__) is essential. Sadly, I am not quite sure how this method gets called or where it is called from.
the __add__
method is used when doing plot1 + plot2
; but you can also get the multi-class version without it:
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn_evaluation import plot
# generate data
X, y = datasets.make_classification(
    n_samples=2000, n_features=6, n_informative=4, class_sep=0.1, n_classes=3
)
# split data into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
est = RandomForestClassifier()
est.fit(X_train, y_train)
# y_pred = est.predict(X_test)
y_score = est.predict_proba(X_test)
y_true = y_test
# plot precision recall curve
pr = plot.PrecisionRecall.from_raw_data(y_true, y_score)
So to keep this simple, let's not modify the __add__
method.
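For reference, Python invokes `__add__` on the left operand whenever `plot1 + plot2` is evaluated. The toy sketch below shows only that dispatch; `ToyCurve` and `ToyCombined` are hypothetical stand-ins, not the library's classes, and no plotting happens.

```python
class ToyCombined:
    """Hypothetical container holding the data of several curves."""

    def __init__(self, precisions, recalls, thresholds, labels):
        self.precisions = precisions
        self.recalls = recalls
        self.thresholds = thresholds
        self.labels = labels


class ToyCurve:
    """Hypothetical stand-in for a single-curve plot object."""

    def __init__(self, precision, recall, threshold, label):
        self.precision = precision
        self.recall = recall
        self.threshold = threshold
        self.label = label

    def __add__(self, another):
        # Python calls this method when evaluating `self + another`.
        return ToyCombined(
            precisions=[self.precision, another.precision],
            recalls=[self.recall, another.recall],
            thresholds=[self.threshold, another.threshold],
            labels=[self.label, another.label],
        )


a = ToyCurve([0.5, 1.0], [1.0, 0.0], [0.4], "model A")
b = ToyCurve([0.6, 1.0], [1.0, 0.0], [0.6], "model B")

# Equivalent to a.__add__(b)
combined = a + b
```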
Now, adding anything to the object returned by _get_data
will break PrecisionRecall
since it inherits the from_dump
method in the AbstractPlot
abstract class. I think adding this to PrecisionRecall
will fix it:
@classmethod
def from_dump(cls, path):
    """Instantiates a plot object from a path to a JSON file. A default
    implementation is provided, but you might override it.
    """
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    del data["version"]
    del data["class"]
    del data["threshold"]
    return cls(**data).plot()
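A self-contained sketch of why the extra key matters: if the constructor does not accept `threshold`, `cls(**data)` raises a `TypeError` unless the key is dropped first. `ToyPlot` and the dumped values are hypothetical, plotting is left out, and `pop(key, None)` is used here instead of `del` as an assumption so that older dumps without a `threshold` key still load.

```python
import json
import tempfile
from pathlib import Path


class ToyPlot:
    """Hypothetical plot whose constructor does not take `threshold`."""

    def __init__(self, precision, recall, label=None):
        self.precision = precision
        self.recall = recall
        self.label = label

    @classmethod
    def from_dump(cls, path):
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        # Drop metadata and any key the constructor does not accept.
        for key in ("version", "class", "threshold"):
            data.pop(key, None)
        return cls(**data)


# Made-up dump contents, including the newly added "threshold" key.
dumped = {
    "version": "0.0.0",
    "class": "ToyPlot",
    "precision": [0.5, 1.0],
    "recall": [1.0, 0.0],
    "label": "toy",
    "threshold": [0.4],
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "plot.json"
    path.write_text(json.dumps(dumped), encoding="utf-8")
    restored = ToyPlot.from_dump(path)
```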
closing due to inactivity, feel free to re-open
Thank you for this great package.
TL;DR
I would like to obtain the thresholds used for the creation of the multiclass precision-recall curve with the
plot.precision_recall()
function.
Details
For binary classification, scikit-learn offers
precision_recall_curve()
(docs), which also returns a list of the thresholds. However, there is no easy way (without lots of boilerplate code) to get multiple precision-recall curves in a multiclass scenario using only scikit-learn. Luckily, sklearn-evaluation provides a way out with its multiclass precision-recall curve plot (here), which produces a publication-quality plot. However, it would be extremely helpful if it were possible to also return a list (or dict) of thresholds, or alternatively to include them in the plot (e.g., a top-border axis with the threshold).
Why ...
... you may ask? This increases the utility and value one can gain from the curve. After all, once finished inspecting the trade-off between precision and recall, one ultimately wants to determine a threshold.
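As one illustration of that last step, the sketch below picks the threshold with the highest F1 score from the three arrays. The helper `best_threshold_by_f1` and the numbers are hypothetical; the only assumption about scikit-learn is that `precision_recall_curve` returns precision and recall arrays with one more element than the thresholds array.

```python
def best_threshold_by_f1(precision, recall, thresholds):
    """Return (threshold, f1) for the point with the highest F1 score."""
    best_threshold, best_f1 = None, -1.0
    # zip stops at the shortest input, so the extra final
    # (precision=1, recall=0) point without a threshold is skipped.
    for p, r, t in zip(precision, recall, thresholds):
        f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
        if f1 > best_f1:
            best_threshold, best_f1 = t, f1
    return best_threshold, best_f1


# Made-up values for one curve.
precision = [0.5, 0.67, 1.0, 1.0]
recall = [1.0, 1.0, 0.5, 0.0]
thresholds = [0.1, 0.4, 0.8]

threshold, f1 = best_threshold_by_f1(precision, recall, thresholds)
```

F1 is only one possible criterion; a different cost trade-off between precision and recall would lead to a different choice.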
Implementation plan
Implement ... what?
Either (or both?)
add_thresholds=False
How to implement it?
List / dict of thresholds: maybe as a class method (since everything is OOP / a class), e.g., like
Argument to include them in the plot: pass
Where to implement?
Some first insights from just scrolling through the code of precision_recall: currently the thresholds are discarded
_
(Line 48-55)