mermast opened this issue 11 months ago.
It looks like you found two bugs at once. The following code works.
```r
library(mlr3)
library(mlr3learners)

set.seed(1234)

tasks = tsks(c("german_credit", "sonar"))
learners = lrns(c("classif.rpart", "classif.ranger", "classif.featureless"), predict_type = "prob")
mlr3misc::walk(learners, function(learner) learner$predict_sets = c("train", "test"))
rsmp_cv5 = rsmp("cv", folds = 5)
design = benchmark_grid(tasks, learners, rsmp_cv5)
bmr = benchmark(design)

measures = list(
  msr("classif.ce", id = "ce.train", predict_sets = "train"),
  msr("classif.ce", id = "ce.test", predict_sets = "test"))
bmr$score(measures)
```
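A variant of the working example, as a sketch. It assumes that `lrns()` sets `predict_sets` as a learner field when it is passed as a named argument, the same way `predict_type` is passed above, so the separate `mlr3misc::walk()` step is not needed. It also assumes `bmr$aggregate()` accepts the same list of measures and averages them over the five folds.

```r
library(mlr3)
library(mlr3learners)

set.seed(1234)

tasks = tsks(c("german_credit", "sonar"))

# Assumption: named arguments to lrns() that match a learner field
# (here predict_sets) are set on every constructed learner.
learners = lrns(
  c("classif.rpart", "classif.ranger", "classif.featureless"),
  predict_type = "prob",
  predict_sets = c("train", "test")
)

design = benchmark_grid(tasks, learners, rsmp("cv", folds = 5))
bmr = benchmark(design)

# Each measure declares which predict set it is computed on.
measures = list(
  msr("classif.ce", id = "ce.train", predict_sets = "train"),
  msr("classif.ce", id = "ce.test", predict_sets = "test")
)

# One row per resampling iteration, one column per measure id.
scores = bmr$score(measures)
print(scores[, c("task_id", "learner_id", "iteration", "ce.train", "ce.test"), with = FALSE])

# Mean over the folds for each task/learner combination.
print(bmr$aggregate(measures))
```

Comparing `ce.train` with `ce.test` side by side makes it easy to spot learners that fit the training folds much better than the held-out folds.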
(1) You have to set the predict sets in the learners (the `mlr3misc::walk()` call in the example above). We should throw an error if a measure requests a predict set that does not exist. (2) The `predict_sets` argument of `$score()` does not seem to work. You have to set `predict_sets` on the measures themselves, as in the example above (the `measures = list(...)` block).
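The error proposed in (1) could look roughly like the check below. `check_predict_sets()` is a hypothetical helper, not part of mlr3; it only compares each measure's `predict_sets` field with each learner's `predict_sets` field before anything is trained.

```r
library(mlr3)

# Hypothetical helper (not part of mlr3): stop early when a measure asks for
# a predict set that one of the learners will not produce.
check_predict_sets = function(learners, measures) {
  for (learner in learners) {
    for (measure in measures) {
      missing_sets = setdiff(measure$predict_sets, learner$predict_sets)
      if (length(missing_sets) > 0) {
        stop(sprintf(
          "Measure '%s' needs predict set(s) '%s', but learner '%s' only predicts on '%s'",
          measure$id,
          paste(missing_sets, collapse = "', '"),
          learner$id,
          paste(learner$predict_sets, collapse = "', '")
        ))
      }
    }
  }
  invisible(TRUE)
}

measures = list(
  msr("classif.ce", id = "ce.train", predict_sets = "train"),
  msr("classif.ce", id = "ce.test", predict_sets = "test")
)

# Learners left at their default predict sets fail the check ...
default_learners = lrns(c("classif.rpart", "classif.featureless"))
res = tryCatch(
  check_predict_sets(default_learners, measures),
  error = function(e) conditionMessage(e)
)
print(res)

# ... while learners that also predict on the training set pass it.
both_learners = lrns(c("classif.rpart", "classif.featureless"), predict_sets = c("train", "test"))
print(check_predict_sets(both_learners, measures))
```

Running such a check inside `benchmark()` or `$score()` would turn the silent mismatch into an explicit error.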
Thanks for your comment!
Description
In both cases, I think it just outputs the test results.
Reproducible example
```r
library(mlr3)
library(mlr3learners)

tasks = tsks(c("german_credit", "sonar"))
learners = lrns(c("classif.rpart", "classif.ranger", "classif.featureless"), predict_type = "prob")
rsmp_cv5 = rsmp("cv", folds = 5)
set.seed(1234)
design = benchmark_grid(tasks, learners, rsmp_cv5)
bmr = benchmark(design)

bmr$score(predict_sets = "train")  # intended: scores on the training set
bmr$score(predict_sets = "test")   # intended: scores on the test set
```
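Following the reply above, a quick way to see which predict sets are involved is to print the `predict_sets` field of the default measure and of a learner built without extra arguments. My expectation, not checked against every mlr3 version, is that both are `"test"`; if so, the default `classif.ce` can only be computed on test predictions in the reproducible example.

```r
library(mlr3)

# Default classification error, as used implicitly by bmr$score().
measure = msr("classif.ce")

# A learner constructed without setting predict_sets.
learner = lrn("classif.rpart")

# Which prediction set(s) the measure is evaluated on.
print(measure$predict_sets)

# Which prediction set(s) the learner produces during resampling.
print(learner$predict_sets)
```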