mlr-org / mlr3

mlr3: Machine Learning in R - next generation
https://mlr3.mlr-org.com
GNU Lesser General Public License v3.0

resample$score(predict_sets = 'test') gives same output as resample$score(predict_sets = 'train'). Also, is there any functionality where results from both train and test can be accessed? #951

Open mermast opened 11 months ago

mermast commented 11 months ago

Description

In both cases, I think it just outputs the test results ...

Reproducible example

tasks = tsks(c("german_credit", "sonar"))
learners = lrns(c("classif.rpart", "classif.ranger", "classif.featureless"), predict_type = "prob")
rsmp_cv5 = rsmp("cv", folds = 5)
set.seed(1234)
design = benchmark_grid(tasks, learners, rsmp_cv5)
bmr = benchmark(design)
bmr$score(predict_sets = 'train')
bmr$score(predict_sets = 'test')

be-marc commented 11 months ago

It looks like you found two bugs at once. The following code works.

library(mlr3)
library(mlr3learners)
set.seed(1234)

tasks = tsks(c("german_credit", "sonar"))
learners = lrns(c("classif.rpart", "classif.ranger","classif.featureless"), predict_type = "prob")

mlr3misc::walk(learners, function(learner) learner$predict_sets = c("train", "test"))

rsmp_cv5 = rsmp("cv", folds = 5)
design = benchmark_grid(tasks, learners, rsmp_cv5)
bmr = benchmark(design)
measures = list(
  msr("classif.ce", id = "ce.train", predict_sets = "train"),
  msr("classif.ce", id = "ce.test", predict_sets = "test"))

bmr$score(measures)

(1) You have to set the predict sets on the learners (line 8, the mlr3misc::walk() call), because by default a learner only predicts on the test set. We should throw an error if a measure requests a non-existent predict set.

(2) The predict_sets argument does not seem to work in $score(). You have to define the measures as in the example above (lines 13-15), with one measure per predict set.
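
The same workaround can be sketched for a single resample() call, which is what the issue title asks about. This is a minimal sketch, not a definitive recipe: it assumes one task and classif.rpart only (so mlr3learners is not needed), and the variable names rr, scores and agg are illustrative. It uses the same msr() arguments as the benchmark example.

```r
library(mlr3)
set.seed(1234)

# Assumption for illustration: one task and one learner keep the run short.
task = tsk("sonar")
learner = lrn("classif.rpart", predict_type = "prob")

# Store predictions for both sets; the default is "test" only.
learner$predict_sets = c("train", "test")

rr = resample(task, learner, rsmp("cv", folds = 5))

# One measure per predict set, each with its own id so the columns differ.
measures = list(
  msr("classif.ce", id = "ce.train", predict_sets = "train"),
  msr("classif.ce", id = "ce.test", predict_sets = "test"))

# Per-fold scores: one row per resampling iteration.
scores = rr$score(measures)
print(scores[, c("iteration", "ce.train", "ce.test")])

# Scores aggregated over the folds, as a named numeric vector.
agg = rr$aggregate(measures)
print(agg)
```

With both measures in one call, the train and test errors appear side by side as the columns ce.train and ce.test, so no second call with a different predict_sets argument is needed.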

Thanks for your comment!