mlr-org / mlr3

mlr3: Machine Learning in R - next generation
https://mlr3.mlr-org.com

Retrieve data from learner #1053

Closed grantmcdermott closed 3 months ago

grantmcdermott commented 3 months ago

Retrieving data from a task is easily done with the data() method:

library(mlr3)
library(rpart)

task_cl = TaskClassif$new("iris", iris, target = "Species")
task_cl$formula(rhs = "Petal.Length + Petal.Width")
#> Species ~ `Petal.Length + Petal.Width`

# retrieve data
task_cl$data(1)
#>    Species Petal.Length Petal.Width Sepal.Length Sepal.Width
#>     <fctr>        <num>       <num>        <num>       <num>
#> 1:  setosa          1.4         0.2          5.1         3.5

However, I'm looking to find an equivalent method that will work on a learner:


fit_cl = lrn("classif.rpart")
fit_cl$train(task_cl)

fit_cl$data(1)
#> Error in eval(expr, envir, enclos): attempt to apply non-function
# the old eval trick doesn't work because the task environment has been masked
eval(fit_cl$model$call$data)
#> Error in eval(fit_cl$model$call$data): object 'task' not found

Created on 2024-07-18 with reprex v2.1.0

Any suggestions would be welcome!

Context: I am developing a simple package that supports 2-D plotting of decision tree partitions, and I need the original data to establish the extent of the plot window. I have been able to find appropriate data retrieval methods for other frontends like partykit and tidymodels, but am struggling with mlr3.

m-muecke commented 3 months ago

What exactly do you mean by retrieving the data? Learners generally don't store the task (input data); you provide it yourself. You can always access the stored model like you've done, and you already have the original data from the input task. But perhaps I don't quite understand the use case.

library(mlr3verse)
#> Loading required package: mlr3

task = tsk("iris")
learner = lrn("classif.rpart")
learner$train(task)
learner$model
#> n= 150 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
#>   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
#>     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
#>     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *

Created on 2024-07-22 with reprex v2.1.1
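
Since the data lives in the task rather than the learner, the ranges needed for the plot window can also be taken directly from the task. A minimal sketch continuing the code above (column names are those of the iris task):

# the task, not the learner, carries the data
dat = task$data()
range(dat$Petal.Length)
range(dat$Petal.Width)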

grantmcdermott commented 3 months ago

Thanks for the reply @m-muecke. Let me try to add some more context.

I have written a package called parttree that produces 2-D plots of simple decision tree partitions. (Here "simple" means the tree contains no more than two explanatory features.) It's not so important, but the original motivation for the package was to help my students visualize how a tree carves up the data space.

While I originally wrote the package with ggplot2 in mind, in the latest development version I'm adding a base plot.parttree method (that leverages tinyplot under the hood to create a legend etc.).

Here's a simple example of the package in action:

pkgload::load_all("~/Documents/Projects/parttree")
library(rpart)

fit = rpart(Kyphosis ~ Start + Age, data = kyphosis)
pt = parttree(fit)
plot(pt, pch = 19, palette = "dark")

What's happening under the hood is that parttree() extracts the partition nodes and coerces them into a simple data frame containing the coordinates of the partition rectangles.


pt
#>   node Kyphosis                                                       path xmin  xmax ymin ymax
#> 1    3  present                                                Start < 8.5 -Inf  8.5 -Inf  Inf
#> 2    4   absent                             Start >= 8.5 --> Start >= 14.5 14.5  Inf -Inf  Inf
#> 3   10   absent                 Start >= 8.5 --> Start < 14.5 --> Age < 55  8.5 14.5 -Inf   55
#> 4   22   absent Start >= 8.5 --> Start < 14.5 --> Age >= 55 --> Age >= 111  8.5 14.5  111  Inf
#> 5   23  present  Start >= 8.5 --> Start < 14.5 --> Age >= 55 --> Age < 111  8.5 14.5   55  111

Importantly—and this is the key part for my current issue here—I also need to store some information about (the extent of) the original data. Why? Well, because otherwise I won't know the limits of the "outer" rectangles at plot time. So the Inf values in the dataframe above get replaced by the relevant values of xrange and yrange below.

attributes(pt)[["parttree"]]
#> $xvar
#> [1] "Start"
#> 
#> $yvar
#> [1] "Age"
#> 
#> $xrange
#> [1]  1 18
#> 
#> $yrange
#> [1]   1 206
#> 
#> $response
#> [1] "Kyphosis"
#> 
#> $call
#> rpart(formula = Kyphosis ~ Start + Age, data = kyphosis)
#> 
#> $na.action
#> NULL
#> 
#> $raw_data
#> NULL

So I need the original data in order to calculate the range of each plotted variable.
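
To make that concrete, here is a minimal sketch of the clamping step (not parttree's actual internals, and it treats pt as a plain data frame):

# replace the infinite limits of the outer rectangles with the observed
# data range so that the plot window is finite
xrange = range(kyphosis$Start)
yrange = range(kyphosis$Age)
pt$xmin = pmax(pt$xmin, xrange[1])
pt$xmax = pmin(pt$xmax, xrange[2])
pt$ymin = pmax(pt$ymin, yrange[1])
pt$ymax = pmin(pt$ymax, yrange[2])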

For mlr3 the workflow would look something like:

library(mlr3)
mytask = tsk("iris")
learner = lrn("classif.rpart")
learner$train(mytask)

pt2 = parttree(learner)
plot(pt2)

Because the user passes only the learner object to parttree(), that is where the function needs to be able to retrieve the original data.

P.S. You might have noticed that in plot(pt) we also get the original data added in as points. That's not essential, but is another reason why I'd like to be able to retrieve the original data.

m-muecke commented 3 months ago

Since the learner doesn't store the task, I see two options:

  1. pass the task as an extra argument to the function
  2. access the model object, i.e. learner$model; this is what some mlr3viz::autoplot() functions do. See the rpart learner as an example: https://github.com/mlr-org/mlr3viz/blob/main/R/LearnerClassifRpart.R#L52.

Since not every model keeps the data in the model object, and since each one would have to be handled differently, I would go for approach 1.
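
A rough sketch of both options (the parttree call in option 1 is illustrative rather than its actual API, and option 2 assumes keep_model = TRUE is passed through to rpart's model argument, which stores the model frame in the fit's $model component):

library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart", keep_model = TRUE)
learner$train(task)

# 1. the caller passes the task (or its data) alongside the learner:
#    parttree(learner, data = task$data())

# 2. pull the training data back out of the fitted rpart object
head(learner$model$model)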

grantmcdermott commented 3 months ago

Thanks, that's helpful.

I think we can close this now, but one quick question first: is there a formal way to check whether keep_model = TRUE was passed to a learner?

m-muecke commented 3 months ago

I'm not sure if there is a formal way, but generally you can retrieve the values from the param set and then check whether the value was passed, i.e. it is not NULL and is set to TRUE, as follows:

library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
learner$param_set
#> <ParamSet(10)>
#>                 id    class lower upper nlevels        default  value
#>             <char>   <char> <num> <num>   <num>         <list> <list>
#>  1:             cp ParamDbl     0     1     Inf           0.01       
#>  2:     keep_model ParamLgl    NA    NA       2          FALSE       
#>  3:     maxcompete ParamInt     0   Inf     Inf              4       
#>  4:       maxdepth ParamInt     1    30      30             30       
#>  5:   maxsurrogate ParamInt     0   Inf     Inf              5       
#>  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[0]>       
#>  7:       minsplit ParamInt     1   Inf     Inf             20       
#>  8: surrogatestyle ParamInt     0     1       2              0       
#>  9:   usesurrogate ParamInt     0     2       3              2       
#> 10:           xval ParamInt     0   Inf     Inf             10      0
# check if it was passed
pars = learner$param_set$get_values()
isTRUE(pars$keep_model)
#> [1] FALSE

learner = lrn("classif.rpart", keep_model = TRUE)
learner$param_set
#> <ParamSet(10)>
#>                 id    class lower upper nlevels        default  value
#>             <char>   <char> <num> <num>   <num>         <list> <list>
#>  1:             cp ParamDbl     0     1     Inf           0.01       
#>  2:     keep_model ParamLgl    NA    NA       2          FALSE   TRUE
#>  3:     maxcompete ParamInt     0   Inf     Inf              4       
#>  4:       maxdepth ParamInt     1    30      30             30       
#>  5:   maxsurrogate ParamInt     0   Inf     Inf              5       
#>  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[0]>       
#>  7:       minsplit ParamInt     1   Inf     Inf             20       
#>  8: surrogatestyle ParamInt     0     1       2              0       
#>  9:   usesurrogate ParamInt     0     2       3              2       
#> 10:           xval ParamInt     0   Inf     Inf             10      0
pars = learner$param_set$get_values()
isTRUE(pars$keep_model)
#> [1] TRUE

Created on 2024-07-24 with reprex v2.1.1
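
Putting the two pieces together, a hypothetical helper (not part of mlr3 or parttree) could check that flag and fall back to requiring the task when the model frame isn't available:

get_plot_data = function(learner, task = NULL) {
  pars = learner$param_set$get_values()
  if (isTRUE(pars$keep_model)) {
    # with keep_model = TRUE, the rpart fit keeps its model frame in $model
    learner$model$model
  } else if (!is.null(task)) {
    task$data()
  } else {
    stop("Retrain with keep_model = TRUE or supply the task.")
  }
}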

grantmcdermott commented 3 months ago

Perfect. Thanks for all your help @m-muecke. Much appreciated!