tdhock / mlr3resampling

Resampling algorithms for mlr3 framework in R

mlr3resampling provides new cross-validation algorithms for the mlr3 framework in R


** Installation

#+begin_src R
install.packages("mlr3resampling") # release version from CRAN
## OR: development version from GitHub:
install.packages("remotes")
remotes::install_github("tdhock/mlr3resampling")
#+end_src


** Description

For an overview of functionality, [[][please read my recent blog post]].

*** Algorithm 1: cross-validation for comparing train on same and other

See examples in [[][ResamplingSameOtherSizesCV vignette]] and data viz for [[][regression]] and [[][classification]].

A supervised learning algorithm inputs a train set and outputs a prediction function, which can be used on a test set. If each data point belongs to a group (such as geographic region, year, etc.), how do we know whether it is possible to train on one group and predict accurately on another? Cross-validation can be used to determine the extent to which this is possible. First, fold IDs from 1 to K are assigned to all data (possibly using stratification, usually by group and label). Then we loop over test sets (group/fold combinations) and train sets (same group, other groups, all groups), and compute test/prediction accuracy for each combination. Comparing test accuracy between same and other tells us the extent to which it is possible to train on one group and predict accurately on another.
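The procedure described above can be sketched in base R on toy data. This is only an illustration and not the package's implementation: the toy data, the helper =get_split=, and the rule that the test fold is excluded from every kind of train set are assumptions made for this sketch.

#+begin_src R
## Toy data (hypothetical): 12 rows, 4 in each of three subsets.
set.seed(1)
toy <- data.frame(
  subset = rep(c("NE", "NW", "S"), each = 4),
  row.id = 1:12)
## Assign fold IDs 1..K within each subset, so each subset appears in each fold.
K <- 2
toy$fold <- ave(toy$row.id, toy$subset, FUN = function(i) {
  sample(rep(1:K, length.out = length(i)))
})
## Hypothetical helper: row indices of one train/test split.
get_split <- function(data, test.subset, test.fold, train.subsets) {
  is.test <- data$subset == test.subset & data$fold == test.fold
  in.train.subset <- switch(
    train.subsets,
    same = data$subset == test.subset,
    other = data$subset != test.subset,
    all = rep(TRUE, nrow(data)))
  ## Assumption: rows in the test fold are never used for training.
  is.train <- in.train.subset & data$fold != test.fold
  list(train = which(is.train), test = which(is.test))
}
## Test on subset S, fold 1, with each kind of train set.
splits <- lapply(
  c(same = "same", other = "other", all = "all"),
  function(train.subsets) get_split(toy, "S", 1, train.subsets))
sapply(splits, function(s) length(s$train)) # same=2, other=4, all=6
#+end_src

In a real experiment, a learner would be trained on each =train= index vector and evaluated on the corresponding =test= rows, so that accuracy for same, other, and all can be compared on identical test sets.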

This is implemented in =ResamplingSameOtherSizesCV= when you use it on a task that defines the =subset= role, for example the Arizona trees data, for which each row is a pixel in an image, and we want to do binary classification -- does the pixel contain a tree or not?

#+begin_src R
> data(AZtrees,package="mlr3resampling")
> table(AZtrees$region3)

  NE   NW    S 
1464 1563 2929 
#+end_src


We see in the output above that the =region3= column has three values (NE, NW, S), each representing the region in which the pixel was found. If we want good predictions in the south (S), can we train on the north (NE+NW)? We can use the code below to set up the CV experiment. Rows 12, 15, and 18 of the output below are the splits that attempt to answer that question (test.subset=S, train.subsets=other).

#+begin_src R
> task.obj <- mlr3::TaskClassif$new("AZtrees3", AZtrees, target="y")
> task.obj$col_roles$feature <- grep("SAMPLE", names(AZtrees), value=TRUE)
> task.obj$col_roles$strata <- "y" #keep data proportional when splitting.
> task.obj$col_roles$group <- "polygon" #keep data together when splitting.
> task.obj$col_roles$subset <- "region3" #fix one test region, train on same/other/all region(s).
> same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
> same_other_sizes_cv$instantiate(task.obj)
> same_other_sizes_cv$instance$iteration.dt[, .(test.subset, train.subsets, test.fold)]
    test.subset train.subsets test.fold
 1:          NE           all         1
 2:          NW           all         1
 3:           S           all         1
 4:          NE           all         2
 5:          NW           all         2
 6:           S           all         2
 7:          NE           all         3
 8:          NW           all         3
 9:           S           all         3
10:          NE         other         1
11:          NW         other         1
12:           S         other         1
13:          NE         other         2
14:          NW         other         2
15:           S         other         2
16:          NE         other         3
17:          NW         other         3
18:           S         other         3
19:          NE          same         1
20:          NW          same         1
21:           S          same         1
22:          NE          same         2
23:          NW          same         2
24:           S          same         2
25:          NE          same         3
26:          NW          same         3
27:           S          same         3
    test.subset train.subsets test.fold
#+end_src

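The rows in the output above represent different kinds of splits. Rows with train.subsets=same train on the same subset as the test set, rows with train.subsets=other train on the other subsets, and rows with train.subsets=all train on all subsets. With three test subsets, three kinds of train sets, and three folds, there are 27 splits in total.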
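The layout of that table can be reproduced with a small base R sketch, which may help when looking up the rows for a particular question. The name =split.grid= is hypothetical, and the sketch only mimics the row order shown above; it does not call the package.

#+begin_src R
## Enumerate every combination; the first column varies fastest,
## which matches the row order of the table above.
split.grid <- expand.grid(
  test.subset = c("NE", "NW", "S"),
  test.fold = 1:3,
  train.subsets = c("all", "other", "same"),
  stringsAsFactors = FALSE)
nrow(split.grid) # 27
## Rows that test on S after training on the other subsets (NE+NW).
which(with(split.grid, test.subset == "S" & train.subsets == "other")) # 12 15 18
#+end_src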

Code to re-run:

#+begin_src R
data(AZtrees,package="mlr3resampling")
table(AZtrees$region3)
task.obj <- mlr3::TaskClassif$new("AZtrees3", AZtrees, target="y")
task.obj$col_roles$feature <- grep("SAMPLE", names(AZtrees), value=TRUE)
task.obj$col_roles$strata <- "y" #keep data proportional when splitting.
task.obj$col_roles$group <- "polygon" #keep data together when splitting.
task.obj$col_roles$subset <- "region3" #fix one test region, train on same/other/all region(s).
same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
same_other_sizes_cv$instantiate(task.obj)
same_other_sizes_cv$instance$iteration.dt[, .(test.subset, train.subsets, test.fold)]
#+end_src


*** Algorithm 2: cross-validation for comparing different sized train sets

See examples in [[][ResamplingSameOtherSizes vignette]] and data viz for [[][regression]] and [[][classification]].

How many train samples are required to get accurate predictions on a test set? Cross-validation with train sets of variable size can be used to answer this question. For example, consider the Arizona trees data below.

#+begin_src R
> dim(AZtrees)
[1] 5956   25
> length(unique(AZtrees$polygon))
[1] 189
#+end_src


The output above indicates we have 5956 rows and 189 polygons. We can do cross-validation on either polygons (if task has =group= role) or rows (if no =group= role set). The code below sets a down-sampling =ratio= of 0.8, and four =sizes= of down-sampled train sets.

#+begin_src R
> same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
> same_other_sizes_cv$param_set$values$ratio <- 0.8
> same_other_sizes_cv$param_set$values$sizes <- 4
> same_other_sizes_cv$instantiate(task.obj)
> same_other_sizes_cv$instance$iteration.dt[, .(n.train.groups, test.fold)]
    n.train.groups test.fold
 1:             51         1
 2:             64         1
 3:             80         1
 4:            100         1
 5:            126         1
 6:             51         2
 7:             64         2
 8:             80         2
 9:            100         2
10:            126         2
11:             51         3
12:             64         3
13:             80         3
14:            100         3
15:            126         3
#+end_src

The output above has one row per train/test split that will be computed in the cross-validation experiment. The full train set size is 126 polygons, and there are four smaller train set sizes (each a factor of 0.8 smaller). Each train set size will be computed for each fold ID from 1 to 3.
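The train set sizes in that output can be checked with a little arithmetic. The sketch below assumes that each smaller size is the full size multiplied by a power of =ratio= and rounded down, and that two of the three folds (126 of 189 polygons) are used for training. This reproduces the numbers shown above, but the rounding rule used inside the package may differ.

#+begin_src R
n.polygons <- 189
n.folds <- 3
ratio <- 0.8
sizes <- 4
## Full train set: all polygons except those in the test fold.
full.size <- n.polygons * (n.folds - 1) / n.folds
full.size # 126
## Assumption: down-sampled sizes are full.size * ratio^k, rounded down.
n.train.groups <- floor(full.size * ratio^(sizes:0))
n.train.groups # 51 64 80 100 126
## One split per train set size and fold.
length(n.train.groups) * n.folds # 15
#+end_src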

Code to re-run:

#+begin_src R
data(AZtrees,package="mlr3resampling")
dim(AZtrees)
length(unique(AZtrees$polygon))
task.obj <- mlr3::TaskClassif$new("AZtrees3", AZtrees, target="y")
task.obj$col_roles$feature <- grep("SAMPLE", names(AZtrees), value=TRUE)
task.obj$col_roles$strata <- "y" #keep data proportional when splitting.
task.obj$col_roles$group <- "polygon" #keep data together when splitting.
same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
same_other_sizes_cv$param_set$values$sizes <- 4
same_other_sizes_cv$param_set$values$ratio <- 0.8
same_other_sizes_cv$instantiate(task.obj)
same_other_sizes_cv$instance$iteration.dt[, .(n.train.groups, test.fold)]
#+end_src


*** Older Usage Examples and Discussion

Older examples in [[][Older resamplers vignette]] (useful for visualization).

The examples linked below use larger data sets than the examples in the CRAN vignettes linked above.

** Related work

mlr3resampling code was copied/modified from Resampling and ResamplingCV classes in the excellent [[][mlr3]] package.