mlr-org / mlr3tuning

Hyperparameter optimization package of the mlr3 ecosystem
https://mlr3tuning.mlr-org.com/
GNU Lesser General Public License v3.0

Tuning SMOTE's K with a trafo fails: 'warning("k should be less than sample size!")' #239

Closed andreassot10 closed 1 month ago

andreassot10 commented 4 years ago

Hello,

I'm having trouble with the trafo function for SMOTE {smotefamily}'s K parameter. In particular, when the number of nearest neighbours K is greater than or equal to the sample size, an error is returned (warning("k should be less than sample size!")) and the tuning process is terminated.

The user cannot control K to be smaller than the sample size during the internal resampling process. This would have to be controlled internally so that if, for instance, trafo_K = 2 ^ K >= sample_size for some value of K, then, say, trafo_K = sample_size - 1.
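
The capping rule described above can be sketched in base R. This is only an illustration: `clamp_k` is a hypothetical helper, not part of mlr3tuning, paradox, or smotefamily, and `sample_size` stands for however many rows the nearest-neighbour search actually sees.

```r
# Hypothetical helper: cap a transformed K at sample_size - 1 (and at least 1).
clamp_k <- function(k, sample_size) {
  stopifnot(sample_size >= 2)
  as.integer(max(1, min(k, sample_size - 1)))
}

trafo_k <- round(3^(1:9))   # candidate K values after a 3^K trafo
sample_size <- 100          # assumed size of the data seen by SMOTE

# Every candidate at or above sample_size collapses onto sample_size - 1.
vapply(trafo_k, clamp_k, integer(1), sample_size = sample_size)
```

Note that all candidates at or above `sample_size` map to the same value, so distinct sampled configurations can end up evaluating the same effective K.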

I was wondering if there's a solution to this or if one is already on its way?

library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# get list of curated binary classification data sets (see https://arxiv.org/abs/1708.03731v2)
ds = listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)
# select imbalanced data sets (without categorical features as SMOTE cannot handle them)
ds = subset(ds, minority.class.size / number.of.instances < 0.2 &
              number.of.symbolic.features == 1)
ds

d = getOMLDataSet(980)
d

# make sure target is a factor and create mlr3 tasks
data = as.data.frame(d)
data[[d$target.features]] = as.factor(data[[d$target.features]])
task = TaskClassif$new(
  id = d$desc$name, backend = data,
  target = d$target.features)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

class_counts <- table(task$truth())
majority_to_minority_ratio <- class_counts[class_counts == max(class_counts)] / 
  class_counts[class_counts == min(class_counts)]

# Pipe operator for SMOTE
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Random Forest learner
rf <- lrn("classif.ranger", predict_type = "prob")

# Pipeline of Random Forest learner with SMOTE
graph <- po_smote %>>%
  po('learner', rf, id = 'rf')
graph$plot()

# Graph learner
rf_smote <- GraphLearner$new(graph, predict_type = 'prob')
rf_smote$predict_type <- 'prob'

# Parameter set in data table format
ps_table <- as.data.table(rf_smote$param_set)
View(ps_table[, 1:4])

# Define parameter search space for the SMOTE parameters
param_set <- ps_table$id %>%
  lapply(
    function(x) {
      if (grepl('smote.', x)) {
        if (grepl('.dup_size', x)) {
          ParamInt$new(x, lower = 1, upper = round(majority_to_minority_ratio))
        } else if (grepl('.K', x)) {
          ParamInt$new(x, lower = 1, upper = round(majority_to_minority_ratio))
        }
      }
    }
  )
param_set <- Filter(Negate(is.null), param_set)
param_set <- ParamSet$new(param_set)

# Apply transformation function on SMOTE's K (= The number of nearest neighbors used for sampling new values. See SMOTE().)
param_set$trafo <- function(x, param_set) {
  index <- which(grepl('.K', names(x)))
  if (sum(index) != 0){
    x[[index]] <- round(3 ^ x[[index]]) #  Intentionally define a trafo that won't work
  }
  x
}

# Define and instantiate resampling strategy to be applied within pipeline
cv <- rsmp("cv", folds = 2)
cv$instantiate(task)

# Set up tuning instance
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 3), 
  store_models = TRUE)
tuner <- TunerRandomSearch$new()

# Tune pipe learner to find optimal SMOTE parameter values
tuner$optimize(instance)

And here's what happens:

INFO  [11:00:14.904] Benchmark with 2 resampling iterations 
INFO  [11:00:14.919] Applying learner 'smote.rf' on task 'optdigits' (iter 2/2) 
Error in get.knnx(data, query, k, algorithm) : ANN: ERROR------->
In addition: Warning message:
In get.knnx(data, query, k, algorithm) : k should be less than sample size!

Session info

R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] smotefamily_1.3.1        OpenML_1.10              mlr3viz_0.1.1.9002      
 [4] mlr3tuning_0.1.2-9000    mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0          
 [7] mlr3learners_0.2.0       mlr3filters_0.2.0.9000   mlr3_0.2.0-9000         
[10] paradox_0.2.0            yardstick_0.0.5          rsample_0.0.5           
[13] recipes_0.1.9            parsnip_0.0.5            infer_0.5.1             
[16] dials_0.0.4              scales_1.1.0             broom_0.5.4             
[19] tidymodels_0.0.3         reshape2_1.4.3           janitor_1.2.1           
[22] data.table_1.12.8        forcats_0.4.0            stringr_1.4.0           
[25] dplyr_0.8.4              purrr_0.3.3              readr_1.3.1             
[28] tidyr_1.0.2              tibble_3.0.1             ggplot2_3.3.0           
[31] tidyverse_1.3.0         

loaded via a namespace (and not attached):
  [1] utf8_1.1.4              tidyselect_1.0.0        lme4_1.1-21            
  [4] htmlwidgets_1.5.1       grid_3.6.2              ranger_0.12.1          
  [7] pROC_1.16.1             munsell_0.5.0           codetools_0.2-16       
 [10] bbotk_0.1               DT_0.12                 future_1.17.0          
 [13] miniUI_0.1.1.1          withr_2.2.0             colorspace_1.4-1       
 [16] knitr_1.28              uuid_0.1-4              rstudioapi_0.10        
 [19] stats4_3.6.2            bayesplot_1.7.1         listenv_0.8.0          
 [22] rstan_2.19.2            lgr_0.3.4               DiceDesign_1.8-1       
 [25] vctrs_0.2.4             generics_0.0.2          ipred_0.9-9            
 [28] xfun_0.12               R6_2.4.1                markdown_1.1           
 [31] mlr3measures_0.1.3-9000 rstanarm_2.19.2         lhs_1.0.1              
 [34] assertthat_0.2.1        promises_1.1.0          nnet_7.3-12            
 [37] gtable_0.3.0            globals_0.12.5          processx_3.4.1         
 [40] timeDate_3043.102       rlang_0.4.5             workflows_0.1.1        
 [43] BBmisc_1.11             splines_3.6.2           checkmate_2.0.0        
 [46] inline_0.3.15           yaml_2.2.1              modelr_0.1.5           
 [49] tidytext_0.2.2          threejs_0.3.3           crosstalk_1.0.0        
 [52] backports_1.1.6         httpuv_1.5.2            rsconnect_0.8.16       
 [55] tokenizers_0.2.1        tools_3.6.2             lava_1.6.6             
 [58] ellipsis_0.3.0          ggridges_0.5.2          Rcpp_1.0.4.6           
 [61] plyr_1.8.5              base64enc_0.1-3         visNetwork_2.0.9       
 [64] ps_1.3.0                prettyunits_1.1.1       rpart_4.1-15           
 [67] zoo_1.8-7               haven_2.2.0             fs_1.3.1               
 [70] furrr_0.1.0             magrittr_1.5            colourpicker_1.0       
 [73] reprex_0.3.0            GPfit_1.0-8             SnowballC_0.6.0        
 [76] packrat_0.5.0           matrixStats_0.55.0      tidyposterior_0.0.2    
 [79] hms_0.5.3               shinyjs_1.1             mime_0.8               
 [82] xtable_1.8-4            XML_3.99-0.3            tidypredict_0.4.3      
 [85] shinystan_2.5.0         readxl_1.3.1            gridExtra_2.3          
 [88] rstantools_2.0.0        compiler_3.6.2          crayon_1.3.4           
 [91] minqa_1.2.4             StanHeaders_2.21.0-1    htmltools_0.4.0        
 [94] later_1.0.0             lubridate_1.7.4         DBI_1.1.0              
 [97] dbplyr_1.4.2            MASS_7.3-51.4           boot_1.3-23            
[100] Matrix_1.2-18           cli_2.0.1               parallel_3.6.2         
[103] gower_0.2.1             igraph_1.2.4.2          pkgconfig_2.0.3        
[106] xml2_1.2.2              foreach_1.4.7           dygraphs_1.1.1.6       
[109] prodlim_2019.11.13      farff_1.1               rvest_0.3.5            
[112] snakecase_0.11.0        janeaustenr_0.1.5       callr_3.4.1            
[115] digest_0.6.25           cellranger_1.1.0        curl_4.3               
[118] shiny_1.4.0             gtools_3.8.1            nloptr_1.2.1           
[121] lifecycle_0.2.0         nlme_3.1-142            jsonlite_1.6.1         
[124] fansi_0.4.1             pillar_1.4.3            lattice_0.20-38        
[127] loo_2.2.0               fastmap_1.0.1           httr_1.4.1             
[130] pkgbuild_1.0.6          survival_3.1-8          glue_1.4.0             
[133] xts_0.12-0              FNN_1.1.3               shinythemes_1.1.2      
[136] iterators_1.0.12        class_7.3-15            stringi_1.4.4          
[139] memoise_1.1.0           future.apply_1.5.0     

Many thanks.

andreassot10 commented 4 years ago

Also on StackOverflow

pfistfl commented 4 years ago

I think this can be seen from two perspectives:

andreassot10 commented 4 years ago

> I think this can be seen from two perspectives:
>
> * If you specify an invalid parameter, the method **should** fail, and a backup learner might catch the error.
>   In this case, we really need to provide an example in the mlr3gallery on how this is done.
>
> * We as package maintainers try to robustify implementations to that degree. This results in equal results for various parameters `K` in this instance. In the latter case, I think this should be moved to `mlr3pipelines`, as this basically is a problem of how we implemented the particular `PipeOp`.
>   In general, though we might also wanna define how we want to deal with such situations where we'd basically have to manipulate hyperparameters so downstream tasks do not fail.

These are good points. I agree that the method should fail when the user specifies an upper bound for K that is greater than or equal to the number of records in the (training) data. But I think there is a certain peculiarity when it comes to SMOTE: even when the supplied upper bound for K is smaller than the number of records, the method may still fail. That's because we may want to tune the pipeline with e.g. CV, where SMOTE is applied on the data subsets (e.g. CV folds). The number of records in these subsets may or may not be greater than K. Each and every time the user changes the train/test split and/or the number of CV folds, they also have to readjust the upper bound for K.
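
To make the dependence on the resampling concrete, here is a rough base-R calculation of the smallest training set that k-fold CV produces. It assumes folds differ in size by at most one row. `min_train_size` is a hypothetical name, the row count is made up for illustration, and the class-specific subset that SMOTE actually searches in may be smaller still.

```r
# Hypothetical helper: size of the smallest training set in k-fold CV,
# assuming folds are as equal as possible (sizes differ by at most one row).
min_train_size <- function(n, folds) {
  stopifnot(folds >= 2, n >= folds)
  n - ceiling(n / folds)
}

n <- 1000                       # assumed number of rows, for illustration only
min_train_size(n, folds = 2)    # 500 rows remain for training
min_train_size(n, folds = 3)    # 666 rows remain for training
```

Any upper bound for K chosen against the full data would need to be rechecked against this smaller number whenever the number of folds changes.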

It's a tricky one, because there is a certain degree of validity in both points you are making, while it's still not easy to decide which one to go for.

I've implemented a fix here with a trafo that forces K to always be smaller than the number of records in the data subsets. It does come with drawbacks though, as it is likely that the value of K can be the same in different runs. Also, note that it depends on CV instantiation, so cv$instantiate(task) will have to precede certain commands.
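
The linked fix is not reproduced here. As a sketch of the general idea only, a trafo can be built by a small factory that closes over a cap computed once from the instantiated resampling. `make_k_trafo` and `max_k` are hypothetical names, and the `3^K` transformation is taken from the reproducer above.

```r
# Hypothetical factory: builds a trafo that caps every "*.K" entry of x
# at max_k, e.g. the smallest training-set size minus one.
make_k_trafo <- function(max_k) {
  function(x, param_set = NULL) {
    is_k <- grepl("\\.K$", names(x))
    x[is_k] <- lapply(x[is_k], function(k) as.integer(min(round(3^k), max_k)))
    x
  }
}

# With an instantiated mlr3 resampling `cv`, max_k could be derived as
#   min(lengths(lapply(seq_len(cv$iters), cv$train_set))) - 1
# (assumes the Resampling fields $iters and $train_set(); not run here).
trafo <- make_k_trafo(max_k = 49)
trafo(list(smote.K = 5, smote.dup_size = 3))
```

Because the cap is fixed when the factory is called, the resampling has to be instantiated first, and re-instantiating it with different folds means rebuilding the trafo.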

I hope this helps.

berndbischl commented 4 years ago

1) Ok, a couple of comments. I think this is somewhat of a "fringe issue"; sorry @andreassot10. Currently, you might have to live with this not working out for you ideally.

2) The value of K refers not to what is in the complete task, but to what is seen inside the training set. That's just how it is.

But what you basically want is that K changes depending on the size of the training set, correct?

berndbischl commented 4 years ago

We can either implement that in SMOTE, as a robust behavior of the algorithm, in which case it's an mlr3pipelines issue, or we can make the trafos better.

@pfistfl what do you think?

andreassot10 commented 4 years ago

> 1. Ok, a couple of comments. I think this is somewhat of a "fringe issue"; sorry @andreassot10.
>    Currently, you might have to live with this not working out for you ideally.
>
> 2. The value of K refers not to what is in the complete task, but to what is seen inside the training set.
>    That's just how it is.
>
> But what you basically want is that K changes depending on the size of the training set, correct?

Thanks @berndbischl. My answers to your comments in reverse order:

Re 2: Exactly. That's where SMOTE is likely to fail. So I'd indeed like the upper bound of K to change depending on the size of the training set. So if you're running a k-fold cross-validation, K should be smaller than the number of rows of the dataset comprising the k-1 training folds.

Re 1: Can definitely live with it for the time being. If I come up with a better fix that ensures the values of K aren't duplicated every now and then, I'll make it public.

Thanks

be-marc commented 1 month ago

This should be possible now with callbacks.