tidymodels / parsnip

A tidy unified interface to models
https://parsnip.tidymodels.org

In classification problems, merging `probably` package when determining best threshold. #986

Open SHo-JANG opened 1 year ago

SHo-JANG commented 1 year ago

As far as I can tell, parsnip uses `prob_to_class_2` by default to turn predicted probabilities into hard class predictions for two-class models.

prob_to_class_2 <- function(x, object) {
  x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
  unname(x)
}

However, in many cases the best threshold is not 0.5, especially with imbalanced datasets.
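To make the concern concrete, here is a minimal sketch in base R. The `object` list is a hypothetical stand-in for a fitted model (only its `lvl` element is used), and the probabilities and the 0.25 cutoff are made up for illustration.

```r
# Hypothetical stand-in for a fitted model object; only `lvl` is used here.
object <- list(lvl = c("no", "yes"))

# The default rule quoted above
prob_to_class_2 <- function(x, object) {
  x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1])
  unname(x)
}

# Made-up predicted probabilities of the second level ("yes") for a rare event
probs <- c(0.05, 0.10, 0.20, 0.30, 0.45)

# With the 0.5 cutoff, no row is ever labelled "yes"
pred_default <- prob_to_class_2(probs, object)
table(pred_default)

# The same rule with an illustrative lower cutoff of 0.25
pred_lower <- ifelse(probs >= 0.25, object$lvl[2], object$lvl[1])
table(pred_lower)
```

When the positive class is rare, predicted probabilities can sit well below 0.5 even for the most likely positives, so the fixed cutoff never predicts the minority class.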

In this case, I wonder if we could use the `threshold_perf()` function from the probably package during tuning, to check whether the model could classify well at a threshold other than 0.5.
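For reference, a short sketch of what `threshold_perf()` does today, outside of any tuning loop. It uses the `segment_logistic` example data that ships with probably; the threshold grid is an arbitrary choice for illustration, and the exact set of default metrics may differ between package versions.

```r
library(probably)

# `segment_logistic` ships with probably and has the columns
# Class (truth), .pred_good and .pred_poor (class probabilities).
perf <- threshold_perf(
  segment_logistic,
  Class,
  .pred_good,
  thresholds = seq(0.1, 0.9, by = 0.1)
)

# One row per threshold and metric; keep the J-index rows, best first
j <- perf[perf$.metric == "j_index", ]
j[order(-j$.estimate), ]
```

This evaluates thresholds on predictions that already exist; it does not change what `predict()` returns for hard classes, which is the gap this issue is about.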

I think this is a really necessary feature. What do you think?

topepo commented 1 year ago

It is an important feature. After the Posit conference, we will be working on post-processing tools, and this is one of them.

We'll try to make it natural so that you can treat the threshold parameter like any other tuning parameter. If you use a workflow, it will also adjust the hard class predictions automatically (once you've picked a threshold).

SHo-JANG commented 1 year ago

Thank you so much for all the hard work you do to make the system more complete.

SHo-JANG commented 11 months ago

I think that treating the threshold as a tuning parameter to find its optimal value would be time consuming and could lead to overfitting.

Instead, I looked for a way to determine the optimal threshold directly (related paper).

In Section 2.3 (Threshold criteria), criterion (6) is PredPrev = Obs. This means that we want the class ratio of the predicted results to equal the ratio of the observed classes in the training data. In other words, we use `quantile(probs = 1 - "Obs class ratio")` of the predicted probability vector as the threshold.

The code to implement this in the training process is as follows.


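prob_to_class_2_custom <- function(x, object) {
  # Proportion of the second level in the training outcome
  # (assumes the underlying fit stores a 0/1 `y`, as glm fits do)
  obs_ratio <- mean(object$fit$y)
  # Cutoff that makes the predicted prevalence match the observed one
  pred_equal_obs_threshold <- quantile(x, probs = 1 - obs_ratio)
  x <- ifelse(x >= pred_equal_obs_threshold, object$lvl[2], object$lvl[1])
  unname(x)
}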
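A quick check of how this rule behaves, as a sketch in base R. The `object` list below is a hypothetical stand-in for a fitted model (a 0/1 outcome in `fit$y` with 20% positives, plus the two levels), and the probabilities are made up.

```r
prob_to_class_2_custom <- function(x, object) {
  obs_ratio <- mean(object$fit$y)
  pred_equal_obs_threshold <- quantile(x, probs = 1 - obs_ratio)
  x <- ifelse(x >= pred_equal_obs_threshold, object$lvl[2], object$lvl[1])
  unname(x)
}

# Hypothetical stand-in for a fitted model: 2 positives out of 10 training rows
object <- list(
  lvl = c("no", "yes"),
  fit = list(y = c(rep(0, 8), rep(1, 2)))
)

# Made-up predicted probabilities of the second level; none reaches 0.5
probs <- c(0.02, 0.05, 0.08, 0.10, 0.12, 0.15, 0.20, 0.25, 0.30, 0.40)

pred <- prob_to_class_2_custom(probs, object)

# Share of rows predicted as "yes"; matches the 20% seen in training
mean(pred == "yes")

# The cutoff is recomputed from whatever is passed in, so a single row
# is its own quantile and is always assigned the second level.
single <- prob_to_class_2_custom(0.02, object)
```

Because the quantile is taken over the vector being predicted, the cutoff depends on the batch passed in. Predicting one row at a time would label every row with the second level, so a cutoff estimated once from training-set predictions may be the safer variant.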

I would like to use this function as the default option. However, it seems that I would need to redefine the engine to apply it. Is there any way to use this function with an existing engine?