mlr-org / mlr3pipelines

Dataflow Programming for Machine Learning in R
https://mlr3pipelines.mlr-org.com/
GNU Lesser General Public License v3.0
140 stars 25 forks source link

assert_binary ERROR on a regression task #509

Closed mohammadreza-sheykhmousa closed 4 years ago

mohammadreza-sheykhmousa commented 4 years ago

Hi mlr3 team,

I am facing this weird error with a regression task

Error in assert_binary(truth, response = response, positive = positive, : Assertion on 'truth' failed: Must be of type 'factor', not 'character'.

which normally happens only for a classification task. I am providing a minimal reprex and session info for your reference. I would appreciate any help here.

# Attach order matters: a package attached later masks same-named exports of
# packages attached earlier.
library("reprex")
library("sp")
library("mlr3verse")
# NOTE(review): attached after mlr3verse, so mlr3measures::tnr (true negative
# rate) masks the tnr() tuner shortcut from mlr3tuning that is called in the
# AutoTuner constructions further down.
library("mlr3measures")
library("bbotk")
library("ggplot2")
library("mltools")
library("data.table")
library("mlr3fselect")
library("FSelectorRcpp")
library("future")
library("future.apply")
library("magrittr")
library("progress")

# Load the meuse data set (presumably shipped with sp, attached above) and
# keep the target plus three predictors, dropping rows with missing values.
data(meuse)
df <- na.omit(meuse[,c("lead","soil","dist","elev")])
target.variable = "lead"
# Task id is the deparsed name of the data frame variable ("df" when this
# runs at top level).
id = deparse(substitute(df))

# Regression task with "lead" as target and the remaining columns as features.
task_regr <-mlr3::TaskRegr$new(id = id, backend = df, target = target.variable)

# NOTE(review): these right-assignments bind the mlr3filters::Filter object to
# the names FilterVariance and FilterFindCorrelation; neither name is used
# again in the visible script.
mlr3filters::Filter -> FilterVariance
mlr3filters::Filter -> FilterFindCorrelation

# Base regression learners, all configured to predict a numeric response.
knn_lrn= lrn("regr.kknn", predict_type = "response")
svm_lrn =lrn("regr.svm", type = "eps-regression", kernel= "radial",predict_type = "response")
xgb_lrn = lrn("regr.xgboost", predict_type = "response")

# First-level "learner_cv" PipeOps, one per base learner, each with its own id
# (the ids are the prefixes used in the ps_ens search space below).
knn_cv1 = po("learner_cv", knn_lrn, id = "knn_1")
svm_cv1 = po("learner_cv", svm_lrn, id = "svm_1")
xgb_cv1 = po("learner_cv", xgb_lrn, id = "xgb_1")

# Feature-filter PipeOps, one per first-level branch.
igain = po("filter", flt("information_gain"), id = "filt1")
variance = po("filter", flt("variance"), id = "filt2")
find_cor = po("filter", flt("find_correlation"), id = "filt3")

# Level 0: three parallel filter -> learner_cv branches plus a "nop" branch,
# combined by gunion() and merged with a "featureunion" PipeOp.
level0 = gunion(list(
igain %>>% knn_cv1,
variance %>>% svm_cv1,
find_cor %>>% xgb_cv1,
po("nop", id = "nop1")))%>>%
po("featureunion", id = "union1")
#level 1
# Second-level "learner_cv" PipeOps built from the same base learners, with
# distinct ids so they can coexist with the first-level ones in one graph.
knn_cv2 = po("learner_cv", knn_lrn , id = "knn_2")
svm_cv2 = po("learner_cv", svm_lrn, id = "svm_2")
xgb_cv2 = po("learner_cv", xgb_lrn, id = "xgb_2")

# Level 1: feed the level-0 output through po("copy", 4) into four branches:
# three go through a PCA (scale. = TRUE) into a second-level learner_cv, the
# fourth is a "nop" branch; the branches are merged by a second featureunion.
level1 = level0 %>>%
po("copy", 4) %>>%
gunion(list(
po("pca", id = "pca2_1", param_vals = list(scale. = TRUE)) %>>% knn_cv2,
po("pca", id = "pca2_2", param_vals = list(scale. = TRUE)) %>>% svm_cv2,
po("pca", id = "pca2_3", param_vals = list(scale. = TRUE)) %>>% xgb_cv2,
po("nop", id = "nop2"))
)%>>%
po("featureunion", id = "union2")

# Final learner placed after the level-1 graph; permutation importance is
# requested from ranger.
ranger_lrn = lrn("regr.ranger", predict_type = "response",importance ="permutation")
# Complete graph: level 1 followed by the ranger learner.
ensemble = level1 %>>% ranger_lrn
# Draw the graph with html = FALSE.
ensemble$plot(html = FALSE)


# Search space for tuning the whole ensemble. Parameter names are prefixed
# with the id of the PipeOp they belong to (filt1, pca2_1, knn_1, ...).
# NOTE(review): knn_1.distance is declared as ParamDbl but knn_2.distance as
# ParamInt — confirm whether the difference is intended.
# NOTE(review): no xgb_1 / xgb_2 parameters are part of this search space.
ps_ens = ParamSet$new(
list(
ParamInt$new("filt1.filter.nfeat", 0, 30),
ParamInt$new("filt2.filter.nfeat", 0, 30),
ParamInt$new("filt3.filter.nfeat", 0, 30),
ParamInt$new("pca2_1.rank.", 0, 50),
ParamInt$new("pca2_2.rank.", 0, 50),
ParamInt$new("pca2_3.rank.", 0, 20),
ParamInt$new("knn_1.k", 1, 3),
ParamDbl$new("knn_1.distance", 1, 3),
ParamDbl$new("svm_1.cost", lower = 2^(-12), upper = 2^(4)),
ParamDbl$new("svm_1.gamma", lower = 2^(-12), upper = 2^(-1)),
ParamInt$new("knn_2.k", 1, 4),
ParamInt$new("knn_2.distance", 1, 3),
ParamDbl$new("svm_2.cost", lower = 2^(-12), upper = 2^(4)),
ParamDbl$new("svm_2.gamma", lower = 2^(-12), upper = 2^(-1)),
ParamInt$new("regr.ranger.mtry", lower = 1, upper = 2),
ParamDbl$new("regr.ranger.sample.fraction", lower = 0.5, upper = 1),
ParamInt$new("regr.ranger.num.trees", lower = 50L, upper = 500L)
))

# Wrap the ensemble graph in a GraphLearner and set its predict type.
ens_lrn = GraphLearner$new(ensemble)
ens_lrn$predict_type = "response"

# Search space for tuning the plain ranger learner on its own (unprefixed
# parameter names, used by auto2 below).
ps_ranger = ParamSet$new(
list(
ParamInt$new("mtry", lower = 1L, upper = 2),
ParamDbl$new("sample.fraction", lower = 0.5, upper = 1),
ParamInt$new("num.trees", lower = 50L, upper = 200L)
))

# Cross-validation resampling shared by both AutoTuners.
# NOTE(review): the variable is named cv3 but is configured with folds = 2.
cv3 = rsmp("cv", folds = 2)

# *ERROR* happening here so I cut off the rest of the code for readability reason.
# NOTE(review): the error recorded below is raised while evaluating the
# `tuner = tnr("random_search")` argument. With mlr3measures attached after
# mlr3verse, the unqualified name tnr resolves to the true-negative-rate
# measure (whose assert_binary() check on 'truth' produces the message), not
# to the tuner shortcut from mlr3tuning.
auto1 = AutoTuner$new(
learner = ens_lrn,
resampling = cv3,
measure = msr("regr.rmse"),
search_space = ps_ens,
terminator = trm("evals", n_evals = 2), 
tuner = tnr("random_search")
)
#> Error in assert_binary(truth, response = response, positive = positive, : Assertion on 'truth' failed: Must be of type 'factor', not 'character'.

# AutoTuner for the simple ranger learner
# NOTE(review): fails with the same message as the ensemble AutoTuner even
# though no graph is involved, which points at the shared
# `tnr("random_search")` argument (name masked by mlr3measures) rather than at
# the pipeline or the regression task.
auto2 = AutoTuner$new(
learner = ranger_lrn,
resampling = cv3,
measure = msr("regr.mse"),
search_space = ps_ranger,
terminator = trm("evals", n_evals = 2), 
tuner = tnr("random_search")
)
#> Error in assert_binary(truth, response = response, positive = positive, : Assertion on 'truth' failed: Must be of type 'factor', not 'character'.

Created on 2020-09-21 by the reprex package (v0.3.0)

Session info ``` r devtools::session_info() #> ─ Session info ─────────────────────────────────────────────────────────────── #> setting value #> version R version 4.0.2 (2020-06-22) #> os Ubuntu 20.04.1 LTS #> system x86_64, linux-gnu #> ui X11 #> language (EN) #> collate en_US.UTF-8 #> ctype en_US.UTF-8 #> tz Europe/Amsterdam #> date 2020-09-21 #> #> ─ Packages ─────────────────────────────────────────────────────────────────── #> package * version date lib #> assertthat 0.2.1 2019-03-21 [1] #> backports 1.1.10 2020-09-15 [1] #> bbotk * 0.2.1 2020-09-05 [1] #> callr 3.4.4 2020-09-07 [1] #> checkmate 2.0.0 2020-02-06 [1] #> cli 2.0.2 2020-02-28 [1] #> codetools 0.2-16 2018-12-24 [3] #> colorspace 1.4-1 2019-03-18 [1] #> crayon 1.3.4 2017-09-16 [1] #> curl 4.3 2019-12-02 [1] #> data.table * 1.13.0 2020-07-24 [1] #> desc 1.2.0 2018-05-01 [1] #> devtools 2.3.1 2020-07-21 [1] #> digest 0.6.25 2020-02-23 [1] #> dplyr 1.0.2 2020-08-18 [1] #> ellipsis 0.3.1 2020-05-15 [1] #> evaluate 0.14 2019-05-28 [1] #> fansi 0.4.1 2020-01-08 [1] #> foreach 1.5.0 2020-03-30 [1] #> fs 1.5.0 2020-07-31 [1] #> FSelectorRcpp * 0.3.3 2020-01-24 [1] #> future * 1.18.0 2020-07-09 [1] #> future.apply * 1.6.0 2020-07-01 [1] #> generics 0.0.2 2018-11-29 [1] #> ggplot2 * 3.3.2 2020-06-19 [1] #> globals 0.12.5 2019-12-07 [1] #> glue 1.4.2 2020-08-27 [1] #> gtable 0.3.0 2019-03-25 [1] #> highr 0.8 2019-03-20 [1] #> hms 0.5.3 2020-01-08 [1] #> htmltools 0.5.0 2020-06-16 [1] #> httr 1.4.2 2020-07-20 [1] #> igraph 1.2.5 2020-03-19 [1] #> iterators 1.0.12 2019-07-26 [1] #> knitr 1.29 2020-06-23 [1] #> lattice 0.20-41 2020-04-02 [3] #> lgr 0.3.4 2020-03-20 [1] #> lifecycle 0.2.0 2020-03-06 [1] #> listenv 0.8.0 2019-12-05 [1] #> magrittr * 1.5 2014-11-22 [1] #> Matrix 1.2-18 2019-11-27 [1] #> memoise 1.1.0 2017-04-21 [1] #> mime 0.9 2020-02-04 [1] #> mlr3 * 0.6.0-9000 2020-09-14 [1] #> mlr3filters * 0.3.0.9000 2020-09-11 [1] #> mlr3fselect * 0.2.0 2020-08-23 [1] #> mlr3learners * 0.3.0 2020-08-29 [1] #> 
mlr3measures * 0.2.0 2020-06-27 [1] #> mlr3misc 0.5.0 2020-08-13 [1] #> mlr3pipelines * 0.2.1-9000 2020-09-09 [1] #> mlr3tuning * 0.2.0 2020-07-28 [1] #> mlr3verse * 0.1.3 2020-07-06 [1] #> mlr3viz * 0.2.0 2020-08-07 [1] #> mltools * 0.3.5 2018-05-12 [1] #> munsell 0.5.0 2018-06-12 [1] #> paradox * 0.4.0 2020-07-21 [1] #> pillar 1.4.6 2020-07-10 [1] #> pkgbuild 1.1.0 2020-07-13 [1] #> pkgconfig 2.0.3 2019-09-22 [1] #> pkgload 1.1.0 2020-05-29 [2] #> prettyunits 1.1.1 2020-01-24 [1] #> processx 3.4.4 2020-09-03 [1] #> progress * 1.2.2 2019-05-16 [1] #> ps 1.3.4 2020-08-11 [1] #> purrr 0.3.4 2020-04-17 [1] #> R6 2.4.1 2019-11-12 [1] #> Rcpp 1.0.5 2020-07-06 [1] #> remotes 2.2.0 2020-07-21 [1] #> reprex * 0.3.0 2019-05-16 [1] #> rlang 0.4.7 2020-07-09 [1] #> rmarkdown 2.3 2020-06-18 [1] #> rprojroot 1.3-2 2018-01-03 [1] #> scales 1.1.1 2020-05-11 [1] #> sessioninfo 1.1.1 2018-11-05 [1] #> sp * 1.4-2 2020-05-20 [1] #> stringi 1.4.6 2020-02-17 [1] #> stringr 1.4.0 2019-02-10 [1] #> testthat 2.3.2 2020-03-02 [2] #> tibble 3.0.3 2020-07-10 [1] #> tidyselect 1.1.0 2020-05-11 [1] #> usethis 1.6.1 2020-04-29 [1] #> uuid 0.1-4 2020-02-26 [1] #> vctrs 0.3.4 2020-08-29 [1] #> withr 2.2.0 2020-04-20 [1] #> xfun 0.16 2020-07-24 [1] #> xml2 1.3.2 2020-04-23 [1] #> yaml 2.2.1 2020-02-01 [1] #> source #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.0) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.0) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 
4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> Github (mlr-org/mlr3@870de0a) #> Github (mlr-org/mlr3filters@faf79ad) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> Github (mlr-org/mlr3pipelines@fae2715) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.0) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.0) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> CRAN (R 4.0.2) #> #> [1] /home/msheykhmousa/R/x86_64-pc-linux-gnu-library/4.0 #> [2] /usr/lib/R/site-library #> [3] /usr/lib/R/library ```
pfistfl commented 4 years ago

R reports the conflict when mlr3measures is loaded: tnr (true negative rate) from mlr3measures masks tnr (tuner shortcut) from mlr3tuning. There should be no reason to load mlr3measures explicitly; either drop that library() call or call mlr3tuning::tnr() instead.