mlr-org / paradox

ParamHelpers Next Generation
https://paradox.mlr-org.com
GNU Lesser General Public License v3.0

Expression params #323

Open mb706 opened 3 years ago

mb706 commented 3 years ago
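library("mlr3")
library("mlr3learners")  # provides the classif.ranger learner
library("mlr3tuning")    # provides AutoTuner, tnr() and trm()
library("paradox")       # provides p_dbl() and to_tune()
# note: ContextPV() only exists in the development branch discussed here

l = lrn("classif.ranger")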

# the Learner class should eventually set this itself:
l$param_set$context_available = "task"

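l$param_set$values$mtry = to_tune(p_dbl(0, 1,
    # tune mtry as a fraction of the number of features; the inner
    # function of `task` is only evaluated once the task is known
    trafo = function(x) ContextPV(function(task)
        max(1, round(length(task$feature_names) * x)),
      x)
))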

at = AutoTuner$new(l, rsmp("holdout"), msr("classif.acc"), trm("none"),
  tnr("grid_search", resolution = 10))
at$train(tsk("iris"))

at$tuning_result
#>        mtry learner_param_vals  x_domain classif.acc
#> 1: 0.7777778          <list[1]> <list[1]>        0.96

at$learner$param_set$values
#> $mtry
#> ContextPV function(task)
#>         max(1, round(length(task$feature_names) * x))
#> Using following environment:
#> $x
#> [1] 0.7777778

at$learner$model
#> Ranger result
#> 
#> Call:
#>  ranger::ranger(dependent.variable.name = task$target_names, data = task$data(),
#>       probability = self$predict_type == "prob", case.weights = task$weights$weight,
#>       mtry = 3) 
#> 
#> Type:                             Classification 
#> Number of trees:                  500 
#> Sample size:                      150 
#> Number of independent variables:  4 
#> Mtry:                             3 
#> Target node size:                 1 
#> Variable importance mode:         none 
#> Splitrule:                        gini 
#> OOB prediction error:             4.67 % 

Note that the mtry the Learner actually sees is 3. It is derived from its $values$mtry, which is

function(task) max(1, round(length(task$feature_names) * x))

with x set to 0.7777778: iris has 4 features, and round(4 * 0.7777778) = round(3.11) = 3. The task is, by default, taken from the environment in which the Learner calls self$param_set$get_values(tags = "train").
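The resolution step described above can be sketched in base R. This is a minimal stand-in, not the code from this branch: `context_pv()` and `resolve_values()` are hypothetical names, and the task is mocked as a plain list with a `feature_names` element.

```r
# Hypothetical stand-in for ContextPV: stores a function of the context
# and gives it an environment holding the closed-over variable `x`
context_pv <- function(fun, x) {
  environment(fun) <- list2env(list(x = x), parent = baseenv())
  structure(list(fun = fun), class = "context_pv")
}

# Hypothetical resolver: evaluates each context-dependent value against
# the given context (here the task) and passes plain values through
resolve_values <- function(values, task) {
  lapply(values, function(v) {
    if (inherits(v, "context_pv")) v$fun(task) else v
  })
}

# Mock task with four features, like iris
task <- list(feature_names = c("Petal.Length", "Petal.Width",
                               "Sepal.Length", "Sepal.Width"))

values <- list(
  mtry = context_pv(function(task)
    max(1, round(length(task$feature_names) * x)), x = 0.7777778),
  num.trees = 500
)

resolved <- resolve_values(values, task)
resolved$mtry  # 3
```

Storing `x` in the function's own environment mirrors the "Using following environment: $x" printout above, so the value stays inspectable after tuning.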


I am not yet completely happy with this solution:


Some problems that arose: what should the behaviour of ParamSetCollection$get_values() be when the context is not fulfilled?
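Two obvious candidate behaviours are to fail, or to hand back the unresolved value and defer resolution. A minimal base R sketch of both, assuming for simplicity that every function-valued entry is context-dependent; `resolve_or_defer()` and its `on_missing` argument are hypothetical, not part of paradox.

```r
# Hypothetical policy for values whose context is not available:
# "error" aborts, "keep" returns the unresolved function unchanged
resolve_or_defer <- function(values, context = NULL,
                             on_missing = c("error", "keep")) {
  on_missing <- match.arg(on_missing)
  lapply(values, function(v) {
    if (!is.function(v)) return(v)            # plain value: nothing to resolve
    if (!is.null(context)) return(v(context)) # context present: evaluate
    if (on_missing == "error") {
      stop("context-dependent value requested, but no context is available")
    }
    v                                         # defer resolution to a later call
  })
}

values <- list(
  mtry = function(task) max(1, round(length(task$feature_names) * 0.5)),
  num.trees = 500
)

# No task available, e.g. values queried outside of training
deferred <- resolve_or_defer(values, on_missing = "keep")
is.function(deferred$mtry)  # TRUE
```

Deferring keeps get_values() usable outside of training, at the cost of callers having to cope with unresolved entries; erroring is stricter but surfaces missing contexts early.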

mllg commented 3 years ago

@berndbischl

mllg commented 3 years ago

What bug is the workaround for?

mb706 commented 3 years ago

This one: https://github.com/mlr-org/mlr3/issues/581

The "workaround" is having the function object inherit the class "function", so it is arguable whether that is simply the way it is supposed to be.
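What "inheriting the class function" means for a classed closure can be shown in base R alone. That the failing check in mlr3 is class-based like `inherits(x, "function")` is an assumption here; the linked issue has the actual details. The class name `ContextPV` is reused only for illustration.

```r
f <- function(task) length(task$feature_names)

# Explicit class attribute replaces the implicit "function" class
only_own <- structure(f, class = "ContextPV")
# Explicit class attribute that keeps "function" as a fallback
with_fn  <- structure(f, class = c("ContextPV", "function"))

is.function(only_own)           # TRUE: typeof() is still "closure"
inherits(only_own, "function")  # FALSE
inherits(with_fn, "function")   # TRUE
with_fn(list(feature_names = 1:4))  # still callable: 4
```

Keeping "function" at the end of the class vector lets class-based checks and S3 dispatch fall back to the behaviour of a plain function.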

berndbischl commented 3 years ago

Hi Martin, thanks a lot for taking the time to look at this. But... you are kind of making it harder than it should be to review this?

1) FunctionParamValue is essential and not explained

2) The code in ParamSet you added is complicated enough to warrant comments

3) Where in the calling code is "env" passed? That's essential. I think you must have done this already in a branch / locally, because your example already works? This must be reviewed too.

4) Please show an example with resampling that uses "n" somewhere, maybe with minsplit. Which n is taken? Exactly the size of the training set?
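To make the question concrete, here is a base R sketch of what the answer would be under one assumption: that during resampling the context passed to the value is the task restricted to the training rows. The mock task and `minsplit_value()` are hypothetical; the 2/3 split mirrors the default `ratio` of `rsmp("holdout")`.

```r
set.seed(1)
n_total <- 150                                      # iris has 150 rows
train_ids <- sample(n_total, round(n_total * 2/3))  # holdout with ratio 2/3

# Hypothetical context-dependent minsplit: a fraction x of the rows
# the learner is actually trained on
minsplit_value <- function(x) {
  function(task) max(2, round(task$nrow * x))
}

# Assumption: the context is the training subset, not the full task
train_task <- list(nrow = length(train_ids))
minsplit_value(0.05)(train_task)  # 5, i.e. 5% of 100 rows
```

If the full task were passed instead, the same fraction would be applied to 150 rows, which is exactly the ambiguity the question points at.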

berndbischl commented 3 years ago

Furthermore, and maybe not essential now (if this works we should REALLY try to use it): I don't think we should make users (even experienced ones) write this: trafo = function(x) FunctionParamValue(function(env)

Probably some sugar / shortcut is needed?
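One possible shape for such sugar, sketched in base R: the user writes a single function of `(x, task)` and a helper builds the nested trafo. Both `ctx_value()` (a stand-in for ContextPV / FunctionParamValue) and `context_trafo()` are hypothetical names, not existing paradox functions.

```r
# Minimal hypothetical stand-in for ContextPV / FunctionParamValue
ctx_value <- function(fun) structure(list(fun = fun), class = "ctx_value")

# Hypothetical sugar: turns f(x, task) into
# function(x) ContextPV(function(task) ...)
context_trafo <- function(f) {
  force(f)
  function(x) {
    force(x)  # fix x now, so later evaluation cannot pick up a changed value
    ctx_value(function(task) f(x, task))
  }
}

trafo <- context_trafo(function(x, task)
  max(1, round(length(task$feature_names) * x)))

pv <- trafo(0.5)
pv$fun(list(feature_names = letters[1:4]))  # 2
```

With such a helper the example from the first comment could read `to_tune(p_dbl(0, 1, trafo = context_trafo(function(x, task) ...)))`, hiding the nested closure from the user.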