mljar / mljar-supervised

Python package for AutoML on Tabular Data with Feature Engineering, Hyper-Parameters Tuning, Explanations and Automatic Documentation
https://mljar.com
MIT License

Metrics computation with imbalance threshold #625

Open anatolereffet opened 1 year ago

anatolereffet commented 1 year ago

Hello Piotr !

Great library so far, I've been enjoying trying out MLJAR. However, I am wondering why, in the binary classification case, the confusion matrix is computed with the threshold that leads to the highest accuracy, instead of one based on the class imbalance in the dataset, for example?

We work with highly imbalanced datasets and usually set the confusion-matrix threshold according to the percentage of positive targets in the dataset, typically with stratified splitting enabled.

Can this be easily fixed by changing the accuracy threshold section of additional_metrics.py so that the user can choose the threshold, or does it need to stay hard-coded (as it currently is for accuracy, even if the user is optimising another metric such as AUC)?

Best regards

pplonski commented 1 year ago

Hi @anatolereffet,

Great question. We use accuracy to find the best threshold because it is the most common metric for binary classification, so the threshold is chosen to maximise it.
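As a rough illustration of what an accuracy-maximising threshold looks like, here is a minimal sketch. It is not the code from additional_metrics.py: the function names `accuracy_at` and `best_accuracy_threshold` and the toy data are assumptions made up for this example, and the search simply tries every distinct predicted probability as a cut-off.

```python
def accuracy_at(y_true, proba_pos, threshold):
    """Accuracy when predicting class 1 for probabilities >= threshold."""
    correct = sum(
        1
        for label, p in zip(y_true, proba_pos)
        if (1 if p >= threshold else 0) == label
    )
    return correct / len(y_true)


def best_accuracy_threshold(y_true, proba_pos):
    """Hypothetical helper: scan candidate cut-offs and keep the most accurate.

    Candidates are the distinct predicted probabilities; on ties the
    lowest threshold wins because only strict improvements are kept.
    """
    best_t, best_acc = None, -1.0
    for t in sorted(set(proba_pos)):
        acc = accuracy_at(y_true, proba_pos, t)
        if acc > best_acc:
            best_t, best_acc = t, acc
    return best_t, best_acc


# Toy imbalanced data (20% positives), invented for illustration only.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
proba_pos = [0.05, 0.10, 0.15, 0.30, 0.10, 0.02, 0.45, 0.08, 0.35, 0.60]

threshold, accuracy = best_accuracy_threshold(y_true, proba_pos)
print(threshold, accuracy)
```

On imbalanced data a search like this tends to favour cut-offs that protect the majority class, which is the behaviour the question above is about.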

You can apply your own threshold technique after getting probabilities from predict_proba(). Please let me know if you need any help.
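A minimal sketch of that idea, using the share of positive labels as the threshold as described earlier in the thread. The helper names `prevalence_threshold` and `confusion_at_threshold` are hypothetical, and `proba_pos` is a hand-written stand-in for the positive-class column you would take from predict_proba(), so nothing here calls the library itself.

```python
def prevalence_threshold(y_true):
    """Fraction of positive labels, used here as the decision threshold."""
    return sum(y_true) / len(y_true)


def confusion_at_threshold(y_true, proba_pos, threshold):
    """Return (tn, fp, fn, tp), predicting class 1 when proba >= threshold."""
    tn = fp = fn = tp = 0
    for label, p in zip(y_true, proba_pos):
        pred = 1 if p >= threshold else 0
        if label == 1 and pred == 1:
            tp += 1
        elif label == 1 and pred == 0:
            fn += 1
        elif label == 0 and pred == 1:
            fp += 1
        else:
            tn += 1
    return tn, fp, fn, tp


# Stand-in values; in practice proba_pos would be the positive-class
# probabilities returned by predict_proba() on a held-out set.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
proba_pos = [0.05, 0.10, 0.15, 0.30, 0.10, 0.02, 0.45, 0.08, 0.35, 0.60]

threshold = prevalence_threshold(y_true)  # 2 positives out of 10
print("prevalence threshold:", confusion_at_threshold(y_true, proba_pos, threshold))
print("default 0.5 threshold:", confusion_at_threshold(y_true, proba_pos, 0.5))
```

With this toy data the prevalence-based cut-off recovers both positives at the cost of two false positives, while 0.5 misses one positive. Whether that trade-off is acceptable depends on the use case.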

anatolereffet commented 1 year ago

Thanks for your precise answer. Yes, I have seen the predict_proba() method, but it only works with the "best model" selected by MLJAR. I have also seen #423, which shows this is a work in progress. I'm more interested in changing the threshold used in the reports for all models. Is it possible to do so with in-library functions, or do I have to implement a change locally to fit my own needs?

Thanks again for your time !

pplonski commented 1 year ago

Hi @anatolereffet,

You need to implement the change locally. Here is the line that sets the threshold: https://github.com/mljar/mljar-supervised/blob/35584462ed0fc6e7345f4999b1019c0990598c07/supervised/utils/additional_metrics.py#L138

Please let me know if you need more help.